cvtools.label_convert.coco.split_coco 源代码

# -*- coding:utf-8 -*-
# author   : gfjiangly
# time     : 2019/7/17 16:51
# e-mail   : jgf0719@foxmail.com
# software : PyCharm
import os.path as osp

import cvtools
from cvtools.utils.misc import sort_dict


class SplitDataset(object):
    """Split a COCO-style annotation file into smaller datasets.

    Two kinds of split are offered:

    * ``split_dataset`` divides the images into a train part and a
      validation part.
    * ``split_cats`` followed by ``save_cat_datasets`` produces one dataset
      per category.
    """

    def __init__(self, ann_file):
        """Load the annotation file.

        Args:
            ann_file: path of the annotation file; it is handed to both
                ``cvtools.load_json`` and ``cvtools.COCO``.
        """
        self.ann_file = ann_file
        self.coco_dataset = cvtools.load_json(ann_file)
        self.COCO = cvtools.COCO(ann_file)
        # Filled by split_cats(). Defined up front so that calling
        # save_cat_datasets() first is a no-op instead of an AttributeError.
        self.catToDatasets = []

    def _build_subset(self, imgs):
        """Assemble a COCO-style dict for an ``{image_id: image_info}`` map."""
        img_to_anns = self.COCO.imgToAnns
        anns = []
        for img_id in imgs:
            # .get() tolerates images that have no annotations even when
            # imgToAnns is a plain dict, and never inserts new keys when it
            # is a defaultdict.
            anns.extend(img_to_anns.get(img_id, []))
        return dict(
            info=self.coco_dataset['info'],
            categories=self.coco_dataset['categories'],
            images=list(imgs.values()),
            annotations=anns
        )

    def split_dataset(self, val_size=0.1, to_file=None):
        """Split the images (and their annotations) into train and val sets.

        Args:
            val_size: forwarded unchanged to ``cvtools.split_dict``
                (presumably the share of images put in the validation
                part -- confirm against cvtools).
            to_file: optional output path. When given, the two parts are
                written next to it as ``train_<name>`` and ``val_<name>``.

        Returns:
            tuple: ``(val, train)`` -- note that the validation dataset
            comes first. Both dicts share the same ``info`` and
            ``categories`` objects as the loaded dataset.
        """
        imgs_train, imgs_val = cvtools.split_dict(self.COCO.imgs, val_size)
        print('images: {} train, {} test.'.format(len(imgs_train),
                                                  len(imgs_val)))

        train = self._build_subset(imgs_train)
        val = self._build_subset(imgs_val)

        if to_file:
            path, name = osp.split(to_file)
            cvtools.dump_json(train, to_file=osp.join(path, 'train_' + name))
            cvtools.dump_json(val, to_file=osp.join(path, 'val_' + name))
        return val, train

    def split_cats(self):
        """Build one COCO-style dataset per category.

        The result is stored in ``self.catToDatasets`` as a list of dicts,
        in the order produced by ``sort_dict(self.COCO.catToImgs)``. Each
        dict holds the shared ``info``, the matching category entries, the
        images listed for that category and the annotations whose
        ``category_id`` equals the category.
        """
        cat_to_imgs = sort_dict(self.COCO.catToImgs)
        self.catToDatasets = []

        # Group categories and annotations once instead of rescanning the
        # full lists for every category; original ordering is preserved.
        cats_by_id = {}
        for cat_info in self.coco_dataset['categories']:
            cats_by_id.setdefault(cat_info['id'], []).append(cat_info)
        anns_by_cat = {}
        for ann_info in self.coco_dataset['annotations']:
            anns_by_cat.setdefault(
                ann_info['category_id'], []).append(ann_info)

        all_images = self.coco_dataset['images']
        for cat, img_ids in cat_to_imgs.items():
            img_ids = set(img_ids)  # O(1) membership tests below
            images = [img_info for img_info in all_images
                      if img_info['id'] in img_ids]
            self.catToDatasets.append(
                {'info': self.coco_dataset['info'],
                 'categories': cats_by_id.get(cat, []),
                 'images': images,
                 'annotations': anns_by_cat.get(cat, [])}
            )

    def save_cat_datasets(self, to_file):
        """Write every per-category dataset built by ``split_cats``.

        Args:
            to_file: output path template containing one ``{}``
                placeholder, which is filled with the category name.

        Datasets without any category entry cannot be named and are
        skipped with a printed notice instead of raising ``IndexError``.
        """
        for dataset in self.catToDatasets:
            categories = dataset['categories']
            if not categories:
                print('skip a dataset without category info: '
                      '{} images, {} annotations.'.format(
                          len(dataset['images']),
                          len(dataset['annotations'])))
                continue
            cvtools.dump_json(
                dataset,
                to_file=to_file.format(categories[0]['name'])
            )
if __name__ == '__main__':
    # Example run: break the DOTA test annotations into one COCO-style
    # file per category; '{}' in the output pattern receives the
    # category name.
    source_json = 'dota/test_dota+crop800x800.json'
    output_pattern = 'dota/test_dota_{}+crop800x800.json'

    splitter = SplitDataset(source_json)
    splitter.split_cats()
    splitter.save_cat_datasets(to_file=output_pattern)