# ----------------------------------------------------------
# From mmdetection
# Modified by jiang.g.f
# ----------------------------------------------------------
import numpy as np
import torch

from .soft_nms_cpu import soft_nms_cpu
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
    """Apply Soft-NMS to a set of detections using the CPU extension.

    ``dets`` may be a ``torch.Tensor`` or a ``numpy.ndarray``. Tensors are
    detached and copied to host memory before the CPU kernel runs, and the
    results are converted back to tensors afterwards.

    Args:
        dets (torch.Tensor | np.ndarray): Detections to suppress. Presumably
            shaped ``(N, 5)`` as ``[x1, y1, x2, y2, score]`` -- TODO confirm
            against ``soft_nms_cpu``.
        iou_thr (float): IoU threshold forwarded to ``soft_nms_cpu``.
        method (str): Score-decay method, ``'linear'`` or ``'gaussian'``.
        sigma (float): Forwarded to ``soft_nms_cpu``; presumably the Gaussian
            decay parameter (NOTE(review): confirm).
        min_score (float): Forwarded to ``soft_nms_cpu``; presumably boxes
            whose decayed score drops below it are discarded (confirm).

    Returns:
        tuple: ``(new_dets, inds)``. For tensor input both are built with
        ``dets.new_tensor`` (same device as ``dets``; ``new_dets`` keeps
        ``dets``' dtype and ``inds`` is ``torch.long``). For ndarray input,
        ``new_dets`` is ``float32`` and ``inds`` is ``int64``.

    Raises:
        TypeError: If ``dets`` is neither a tensor nor a numpy array.
        ValueError: If ``method`` is not ``'linear'`` or ``'gaussian'``.
    """
    if isinstance(dets, torch.Tensor):
        is_tensor = True
        # The kernel operates on numpy data, so move off the autograd graph
        # and off any accelerator device first.
        dets_np = dets.detach().cpu().numpy()
    elif isinstance(dets, np.ndarray):
        is_tensor = False
        dets_np = dets
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    # The CPU kernel identifies the decay method by an integer code.
    method_codes = {'linear': 1, 'gaussian': 2}
    if method not in method_codes:
        raise ValueError('Invalid method for SoftNMS: {}'.format(method))

    new_dets, inds = soft_nms_cpu(
        dets_np,
        iou_thr,
        method=method_codes[method],
        sigma=sigma,
        min_score=min_score)

    if is_tensor:
        # new_tensor places results on the same device as the input.
        return dets.new_tensor(new_dets), dets.new_tensor(
            inds, dtype=torch.long)
    return new_dets.astype(np.float32), inds.astype(np.int64)