# Source code for cvtools.data_augs.crop.crop

# -*- encoding:utf-8 -*-
# @Time    : 2019/11/20 16:12
# @Author  : gfjiang
# @Site    : 
# @File    : crop.py
# @Software: PyCharm
"""顶层通用Crop类,此类无须修改。只需将CropDataset和CropMethod传递给此类。
支持对少样本类别过采样,支持自定义裁剪数据集,支持自定义裁剪方法"""
import os.path as osp
from collections import defaultdict

import cvtools
from cvtools.data_augs.crop.crop_abc import Crop
from cvtools.data_augs.crop.crop_method import CropImageProtected


class CropLargeImages(Crop):
    """Top-level generic cropper for large images.

    Combines a crop dataset (supplies images/annotations and saves results)
    with a crop method (decides where to cut). Optionally over-samples rare
    categories by running extra protected crops around their instances.

    Args:
        dataset: crop dataset. Must support ``len()`` and integer indexing;
            each item is a dict with an ``'image'`` path and an ``'anns'``
            list of annotation dicts (each with ``'category_id'``). Must also
            expose ``crop_dataset['categories']`` (dicts with ``'id'`` and
            ``'name'``) and a ``save(crops, to_file, limit_border)`` method.
        crop_method: object providing ``crop(img, anns)``,
            ``match_anns(anns)`` (returns a dict-like result supporting
            ``update``), ``get_stats()`` and ``reset_stats()``.
        over_strict (bool): ``strict`` flag forwarded to the
            ``CropImageProtected`` cropper used for over-sampling.
    """

    def __init__(self, dataset, crop_method, over_strict=True):
        self.dataset = dataset
        self.crop_method = crop_method
        self.over_strict = over_strict
        # Misspelled name kept as an alias so existing readers keep working.
        self.ovover_strict = over_strict
        # One entry per processed image, appended by crop_for_train() and
        # cleared by save().
        self.crops = []
        self.cat_id_to_name = {
            cat['id']: cat['name']
            for cat in self.dataset.crop_dataset['categories']
        }
        self.crop_for_protected = CropImageProtected(strict=over_strict)

    def crop_for_train(self, over_samples=None):
        """Crop every image of the dataset for training.

        For each image, the crop method is run and its result is matched
        against the image's annotations. The matched crops are appended to
        ``self.crops``. The result may be empty: then nothing is cropped and
        the original image is not kept either. When the loop ends, the crop
        method's statistics are printed and reset.

        Args:
            over_samples (dict, optional): ``{category_name: times, ...}``.
                For every image that contains a listed category, ``times``
                extra protected crops are attempted around its instances.
        """
        total = len(self.dataset)
        for i in range(total):
            # Each dataset item must provide 'image' and 'anns' fields.
            data = self.dataset[i]
            img_path = data['image']
            anns = data['anns']
            if not anns:
                # Only a warning: the image is still passed to the cropper.
                print('{} has no label'.format(img_path))
            img = cvtools.imread(img_path)
            self.crop_method.crop(img, anns)
            # May be empty: no crop produced and the original is not kept.
            cropped = self.crop_method.match_anns(anns)
            # Over-sample rare categories with additional protected crops.
            if over_samples is not None:
                cropped.update(self.over_sample(img, anns, over_samples))
            self.crops.append(cropped)
            # 1-based index so the last image reports "N of N".
            print('crop image %d of %d: %s'
                  % (i + 1, total, osp.basename(img_path)))
        # Print and reset the statistics gathered by the crop method.
        print(self.crop_method.get_stats())
        self.crop_method.reset_stats()
        if hasattr(self.crop_method, 'stats_crop'):
            print(self.crop_method.stats_crop)
            self.crop_method.stats_crop = {}

    def over_sample(self, img, anns, over_samples):
        """Generate extra crops around instances of rare categories.

        Args:
            img: image as returned by ``cvtools.imread``.
            anns (list): all annotations of ``img``.
            over_samples (dict): ``{category_name: times, ...}``.

        Returns:
            dict: merged ``match_anns`` results of every successful protected
            crop; empty if no category matched or every attempt failed.
        """
        add_crops = {}
        for over_cat, times in over_samples.items():
            # Instances of the rare category that the crop must protect.
            protected_anns = [
                ann for ann in anns
                if self.cat_id_to_name[ann['category_id']] == over_cat
            ]
            if not protected_anns:
                continue
            for _ in range(times):
                if len(self.crop_for_protected(img, protected_anns)):
                    # Must match against ALL annotations, not just the
                    # protected ones (an earlier bug fix) -- presumably so
                    # other objects inside the crop keep their labels.
                    add_crops.update(self.crop_for_protected.match_anns(anns))
                else:
                    print('Protection cropping failure!')
        return add_crops

    def crop_for_test(self):
        """Test-set cropping. Not implemented; currently a no-op."""
        pass

    def save(self, to_file, limit_border=False):
        """Save the accumulated crops through the dataset, then clear them.

        Args:
            to_file: output destination, passed to ``self.dataset.save``.
            limit_border (bool): forwarded unchanged to ``self.dataset.save``.
        """
        self.dataset.save(self.crops, to_file, limit_border)
        self.crops = []