# -*- encoding:utf-8 -*-
# @Time : 2019/12/26 16:00
# @Author : jiang.g.f
# @File : coco_to_dets.py
# @Software: PyCharm
import numpy as np
from collections import defaultdict
import cv2.cv2 as cv
import cvtools
class COCO2Dets(object):
    """Convert DOTA-style, COCO-compatible ground truth into detection results.

    Ground-truth annotations are re-expressed in the layout of a detector's
    output, so they can be saved (pickle by default) and fed to code that
    consumes detection results.

    results:
        {
            image_id: dets,  # image_id must be a valid id in the annotations
            image_id: dets,
            ...
        }
    dets:
        {
            cls_id: [[location coordinates..., score], [...], ...],
            cls_id: [[location coordinates..., score], [...], ...],
            ...
        }

    After :meth:`convert`, each ``dets[cls_id]`` is a float64 ``np.ndarray``
    with one row per annotation. With the default :meth:`handle_ann`, the
    score column is always ``1.0``.
    """

    def __init__(self, anns_file, num_coors=4):
        """
        Args:
            anns_file: path to a COCO-format annotation file (loaded with
                ``cvtools.COCO``), or an already loaded object exposing
                ``imgs``, ``getAnnIds`` and ``loadAnns``.
            num_coors: number of location values per detection. Use 4 for
                the output of ``cvtools.x1y1wh_to_x1y1x2y2`` applied to
                ``ann['bbox']``, or 8 for a four-point polygon taken from
                ``ann['segmentation']``.
        """
        assert num_coors in (4, 8), "不支持的检测位置表示"
        self.coco = anns_file
        if cvtools.is_str(anns_file):
            self.coco = cvtools.COCO(anns_file)
        # defaultdict() without a default_factory behaves like a plain dict
        # (missing keys raise KeyError); the type is kept for compatibility.
        self.results = defaultdict()
        self.num_coors = num_coors

    def handle_ann(self, ann):
        """Turn one annotation into a detection row ``[coords..., score]``.

        To customise how annotations are converted, subclass this class and
        override this method.

        Args:
            ann: a COCO annotation dict. ``ann['bbox']`` is read when
                ``num_coors == 4``. The first polygon of
                ``ann['segmentation']`` is read when ``num_coors == 8``.

        Returns:
            list: the location values followed by a constant score of 1.0.

        Raises:
            RuntimeError: if ``self.num_coors`` is neither 4 nor 8.
        """
        if self.num_coors == 4:
            bboxes = cvtools.x1y1wh_to_x1y1x2y2(ann['bbox'])
        elif self.num_coors == 8:
            segm = ann['segmentation'][0]
            if len(segm) != 8:
                # Not already four points: fit the minimum-area rotated
                # rectangle around the polygon's convex hull and use its
                # four corner points instead.
                segm_hull = cv.convexHull(
                    np.array(segm).reshape(-1, 2).astype(np.float32),
                    clockwise=False)
                xywha = cv.minAreaRect(segm_hull)
                segm = cv.boxPoints(xywha).reshape(-1).tolist()
            bboxes = segm
        else:
            raise RuntimeError('不支持的坐标数!')
        # Normalise to a list before appending the score. On a tuple,
        # ``+ [1.]`` raises TypeError; on an ndarray it would broadcast-add
        # 1.0 instead of appending a column.
        return list(bboxes) + [1.]

    def convert(self, to_file=None):
        """Build ``self.results`` for every image in the annotation set.

        Args:
            to_file: optional output path. When given, the results are also
                written with :meth:`save_pkl`.

        Returns:
            ``self.results``: a mapping ``image_id -> dets``, where ``dets``
            is a ``defaultdict(list)`` mapping ``category_id`` to a float64
            array with one row per annotation. Images without annotations
            map to an empty ``dets``.
        """
        for img_id in self.coco.imgs:
            dets = defaultdict(list)
            ann_ids = self.coco.getAnnIds(imgIds=[img_id])
            for ann in self.coco.loadAnns(ann_ids):
                dets[ann['category_id']].append(self.handle_ann(ann))
            for cls, det in dets.items():
                # ``np.float`` (an alias of the builtin float, i.e. float64)
                # was removed in NumPy 1.24, so use the explicit dtype.
                dets[cls] = np.array(det, dtype=np.float64)
            self.results[img_id] = dets
        if to_file is not None:
            self.save_pkl(to_file)
        return self.results

    def save_pkl(self, to_file):
        """Write ``self.results`` to ``to_file`` via ``cvtools.dump_pkl``."""
        cvtools.dump_pkl(self.results, to_file)

    def save_json(self, to_file):
        """Write ``self.results`` to ``to_file`` via ``cvtools.dump_json``.

        NOTE(review): after :meth:`convert`, the results hold numpy arrays
        and non-string keys. Whether they serialise depends on
        ``cvtools.dump_json`` — TODO confirm.
        """
        cvtools.dump_json(self.results, to_file)