cvtools.data_augs.crop.crop_dataset 源代码

# -*- encoding:utf-8 -*-
# @Time    : 2019/11/20 14:53
# @Author  : gfjiang
# @Site    : 
# @File    : crop_dataset.py
# @Software: PyCharm
"""所有需要裁剪的数据集需继承CropDataset,必须实现__getitem__和__len__接口"""
import os.path as osp
import numpy as np
import copy

import cvtools
from cvtools.data_augs.crop.crop_abc import CropDataset
from cvtools.data_augs.crop.crop_abc import cal_iof
from cvtools.utils.boxes import x1y1wh_to_x1y1x2y2x3y3x4y4
from cvtools.utils.boxes import trans_polygon_to_rbox
from cvtools.utils.boxes import cut_polygon


"""
迭代器是一种最简单也最常见的设计模式。
它可以让用户透过特定的接口巡访容器中的每一个元素而不用了解底层的实现。
"""


[文档]class CocoDatasetForCrop(CropDataset): def __init__(self, img_prefix, ann_file): super(CocoDatasetForCrop, self).__init__() self.img_prefix = img_prefix self.ann_file = ann_file self.crop_dataset = cvtools.load_json(ann_file) self.COCO = cvtools.COCO(ann_file) image_ids = self.COCO.getImgIds() image_ids.sort() self.roidb = self.COCO.loadImgs(image_ids) if cvtools._DEBUG: self.roidb = self.roidb[:cvtools._NUM_DATA] def __getitem__(self, item): entry = self.roidb[item] image_name = entry['file_name'] image_file = osp.join(self.img_prefix, image_name) ann_ids = self.COCO.getAnnIds(imgIds=entry['id'], iscrowd=None) anns = self.COCO.loadAnns(ann_ids) anns = [ { 'id': ann['id'], 'bbox': ann['bbox'], 'segmentation': ann['segmentation'], 'category_id': ann['category_id'] } # 此处可以订制 for ann in anns] return {'image': image_file, 'anns': anns} def __len__(self): return len(self.roidb)
[文档] def save(self, crops, to_file, limit_border=False): """通过自然索引对齐两组数据要小心""" assert len(self.roidb) == len(crops) new_images = [] new_annotations = [] image_id = 1 ann_id = 1 for image_i in range(len(self.roidb)): image_info = self.roidb[image_i] ann_ids = self.COCO.getAnnIds(imgIds=image_info['id'], iscrowd=None) anns = self.COCO.loadAnns(ann_ids) if not crops[image_i]: continue for img_box, ann_indexes in crops[image_i].items(): new_image_info = copy.deepcopy(image_info) new_image_info['crop'] = img_box new_image_info['id'] = image_id new_images.append(new_image_info) crop_anns = [anns[index] for index in ann_indexes] # 不能修改原始数据,因为同一个ann可能分布在多个图片中 crop_anns = copy.deepcopy(crop_anns) if limit_border: self.recalc_anns(img_box, crop_anns) for ann in crop_anns: self.trans_ann(ann, img_box) ann['id'] = ann_id ann['image_id'] = image_id new_annotations.append(ann) ann_id += 1 image_id += 1 self.crop_dataset['images'] = new_images self.crop_dataset['annotations'] = new_annotations cvtools.dump_json(self.crop_dataset, to_file)
[文档] def trans_ann(self, ann, img_box): segm = np.array(ann['segmentation'][0]).reshape(-1, 2) new_segm = segm - np.array(img_box[:2]).reshape(-1, 2) ann['segmentation'] = [new_segm.reshape(-1).tolist()] ann['bbox'][0] -= img_box[0] ann['bbox'][1] -= img_box[1]
[文档] def recalc_anns(self, img_box, anns): segms = np.array([ann['segmentation'][0] for ann in anns]) boxes = [ann['bbox'] for ann in anns] iof = cal_iof(boxes, [img_box])[..., -1] cutted_segms_ids = np.where(iof < 1)[0] for id in cutted_segms_ids: segm = segms[id] img_box_polygon = np.array(x1y1wh_to_x1y1x2y2x3y3x4y4( cvtools.x1y1x2y2_to_x1y1wh(img_box))).reshape(-1, 2) # 注意这里polygon不一定为四边形 inters, bounds = cut_polygon(segm, img_box_polygon) # TODO: 这里通用性较差 if len(inters.reshape(-1).tolist()) != len(segm): inters = trans_polygon_to_rbox(inters) anns[id]['segmentation'] = [inters] anns[id]['bbox'] = cvtools.x1y1x2y2_to_x1y1wh(list(bounds))