#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import io
import logging
import os
import os.path as osp
import numpy as np
import torch
from cvpods.structures import BoxMode
from cvpods.utils import PathManager, Timer
from ..base_dataset import BaseDataset
from ..detection_utils import (
annotations_to_instances,
check_image_size,
create_keypoint_hflip_indices,
filter_empty_instances,
read_image
)
from ..registry import DATASETS
from .paths_route import _PREDEFINED_SPLITS_CITYPERSONS
"""
This file contains functions to parse COCO-format annotations into dicts in "cvpods format".
"""
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@DATASETS.register()
class CityPersonsDataset(BaseDataset):
    """CityPersons detection dataset read from a COCO-format annotation file.

    The annotation json and the image root are looked up in
    ``_PREDEFINED_SPLITS_CITYPERSONS["citypersons"]`` under ``self.name``.
    Annotations are parsed once at construction time into a list of
    per-image dicts (``self.dataset_dicts``); images are only read lazily in
    ``__getitem__``.
    """

    def __init__(self, cfg, dataset_name, transforms=None, is_train=True):
        """Resolve paths, load annotations and cache the relevant config values.

        Args:
            cfg: config node; the fields read here are ``INPUT.FORMAT``,
                ``INPUT.MASK_FORMAT``, ``MODEL.MASK_ON``, ``MODEL.KEYPOINT_ON``,
                ``MODEL.LOAD_PROPOSALS``, ``DATALOADER.FILTER_EMPTY_ANNOTATIONS``,
                ``DATASETS.PROPOSAL_FILES_TRAIN``, ``DATASETS.TRAIN``,
                ``TEST`` (optional ``WITH_GT`` key) and, when keypoints are on,
                ``MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE``.
            dataset_name (str): forwarded to the base class.
                NOTE(review): the base class is presumed to store it as
                ``self.name`` and to provide ``self.data_root`` - confirm.
            transforms (list | None): transforms forwarded to the base class.
                ``None`` (the default) is replaced by a fresh empty list so
                that instances never share one mutable default object.
            is_train (bool): when true, annotations are additionally filtered.
        """
        if transforms is None:
            transforms = []
        super().__init__(cfg, dataset_name, transforms, is_train)

        image_root, json_file = _PREDEFINED_SPLITS_CITYPERSONS["citypersons"][self.name]
        # Entries containing "://" are treated as already fully qualified and
        # are therefore not prefixed with the local data root.
        is_remote = "://" in image_root
        self.json_file = (
            osp.join(image_root, json_file) if is_remote
            else osp.join(self.data_root, json_file)
        )
        self.image_root = image_root if is_remote else osp.join(self.data_root, image_root)

        # _get_metadata() reads self.image_root / self.json_file, so it must
        # run after both have been assigned.
        self.meta = self._get_metadata()
        self.dataset_dicts = self._load_annotations(
            self.json_file,
            self.image_root,
            self.name,
            extra_annotation_keys=None,
        )

        # fmt: off
        self.data_format    = cfg.INPUT.FORMAT
        self.mask_on        = cfg.MODEL.MASK_ON
        self.mask_format    = cfg.INPUT.MASK_FORMAT
        self.filter_empty   = cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS
        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        self.proposal_files = cfg.DATASETS.PROPOSAL_FILES_TRAIN
        # fmt: on

        if is_train:
            self.dataset_dicts = self._filter_annotations(
                filter_empty=self.filter_empty,
                min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
                if self.keypoint_on else 0,
                proposal_files=self.proposal_files if self.load_proposals else None,
            )
            # NOTE(review): the original indentation was lost; the group flag
            # is assumed to be set only for training - confirm.
            self._set_group_flag()

        self.eval_with_gt = cfg.TEST.get("WITH_GT", False)

        if self.keypoint_on:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = create_keypoint_hflip_indices(
                cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

    def __getitem__(self, index):
        """Load one sample, apply transforms and convert annotations to instances.

        Args:
            index (int): position in ``self.dataset_dicts``.

        Returns:
            dict: a deep copy of the stored record without its "annotations"
            entry, plus "image" (a tensor in channel-first layout) and, when
            the record carried annotations, "instances".
        """
        # Deep copy so that the pop() below never alters the cached record.
        dataset_dict = copy.deepcopy(self.dataset_dicts[index])

        # read image
        image = read_image(dataset_dict["file_name"], format=self.data_format)
        check_image_size(dataset_dict, image)

        if "annotations" in dataset_dict:
            # Crowd regions are dropped before transforms are applied.
            annotations = [
                ann for ann in dataset_dict.pop("annotations")
                if ann.get("iscrowd", 0) == 0
            ]
        else:
            annotations = None

        # apply transform
        image, annotations = self._apply_transforms(image, annotations)

        if annotations is not None:
            image_shape = image.shape[:2]  # h, w
            instances = annotations_to_instances(
                annotations, image_shape, mask_format=self.mask_format
            )
            dataset_dict["instances"] = filter_empty_instances(instances)

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of
        # pickle & mp.Queue. Therefore it's important to use torch.Tensor.
        # h, w, c -> c, h, w
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)))

        return dataset_dict

    def __reset__(self):
        """Not supported for this dataset; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of image records currently held."""
        return len(self.dataset_dicts)

    def _load_annotations(  # noqa
        self,
        json_file,
        image_root,
        dataset_name=None,
        extra_annotation_keys=None
    ):
        """Load a json file in COCO's instances annotation format.

        Args:
            json_file (str): full path to the json file in COCO instances
                annotation format.
            image_root (str): directory joined in front of every image's
                "file_name".
            dataset_name (str | None): when not None, "thing_classes" and
                "thing_dataset_id_to_contiguous_id" are written into
                ``self.meta`` and category ids are remapped to contiguous
                zero-based ids.
            extra_annotation_keys (list[str] | None): per-annotation keys that
                are copied as-is into each object dict, in addition to
                "iscrowd", "bbox", "keypoints" and "category_id".

        Returns:
            list[dict]: one record per image with "file_name", "height",
            "width", "image_id" and "annotations" (a list of object dicts).

        Raises:
            AssertionError: if annotation ids are not unique (skipped for files
                whose path contains "minival"), or if an annotation's
                "image_id" does not match the image it is listed under.

        Notes:
            This function does not read the image files; the results do not
            have an "image" field.
        """
        from pycocotools.coco import COCO

        timer = Timer()
        json_file = PathManager.get_local_path(json_file)
        # pycocotools prints progress to stdout while indexing; silence it.
        with contextlib.redirect_stdout(io.StringIO()):
            coco_api = COCO(json_file)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(
                json_file, timer.seconds()))

        id_map = None
        if dataset_name is not None:
            meta = self.meta
            cat_ids = sorted(coco_api.getCatIds())
            cats = coco_api.loadCats(cat_ids)
            # The categories in a custom json file may not be sorted.
            meta["thing_classes"] = [
                c["name"] for c in sorted(cats, key=lambda x: x["id"])
            ]

            # Category ids in the json need not be contiguous. They are
            # translated to contiguous ids in [0, #categories) based on the
            # "categories" field; a warning is printed when the ids are not
            # exactly 1..#categories (except for datasets named "coco").
            if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
                if "coco" not in dataset_name:
                    logger.warning(
                        "\nCategory ids in annotations are not in [1, #categories]! "
                        "We'll apply a mapping for you.\n"
                    )
            id_map = {v: i for i, v in enumerate(cat_ids)}
            meta["thing_dataset_id_to_contiguous_id"] = id_map

        # sort indices for reproducible results
        img_ids = sorted(coco_api.imgs.keys())
        # imgs is a list of dicts, one per image, e.g.
        # {'file_name': '...', 'height': 427, 'width': 640, 'id': 1268, ...}
        imgs = coco_api.loadImgs(img_ids)
        # anns is a list[list[dict]]: the outer list enumerates images, the
        # inner list enumerates the annotation records of that image, e.g.
        # {'segmentation': [[...]], 'area': 1035.749, 'iscrowd': 0,
        #  'image_id': 1268, 'bbox': [192.81, 224.8, 74.73, 33.43],
        #  'category_id': 16, 'id': 42986}
        anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]

        if "minival" not in json_file:
            # The popular valminusminival & minival annotations for COCO2014
            # contain duplicated ids. The ratio of buggy annotations there is
            # tiny and does not affect accuracy, so they are white-listed.
            ann_ids = [
                ann["id"] for anns_per_image in anns for ann in anns_per_image
            ]
            assert len(set(ann_ids)) == len(
                ann_ids), "Annotation ids in '{}' are not unique!".format(
                    json_file)

        imgs_anns = list(zip(imgs, anns))
        logger.info("Loaded {} images in COCO format from {}".format(
            len(imgs_anns), json_file))

        dataset_dicts = []
        ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"
                    ] + (extra_annotation_keys or [])

        num_instances_without_valid_segmentation = 0
        num_instances_without_valid_bbox = 0

        for img_dict, anno_dict_list in imgs_anns:
            record = {
                "file_name": os.path.join(image_root, img_dict["file_name"]),
                "height": img_dict["height"],
                "width": img_dict["width"],
            }
            image_id = record["image_id"] = img_dict["id"]

            objs = []
            for anno in anno_dict_list:
                # The image_id stored in the annotation must match the image
                # it is listed under. This fails only when the data parsing
                # logic or the annotation file is buggy.
                assert anno["image_id"] == image_id

                # Some annotations have negative coordinates
                if (np.asarray(anno["bbox"]) < 0).any():
                    num_instances_without_valid_bbox += 1
                    continue

                # Annotations flagged with a non-zero "ignore" are skipped.
                if anno.get("ignore", 0) != 0:
                    continue

                obj = {key: anno[key] for key in ann_keys if key in anno}

                segm = anno.get("segmentation", None)
                if segm:  # either list[list[float]] or dict(RLE)
                    if not isinstance(segm, dict):
                        # filter out invalid polygons (< 3 points)
                        segm = [
                            poly for poly in segm
                            if len(poly) % 2 == 0 and len(poly) >= 6
                        ]
                        if len(segm) == 0:
                            num_instances_without_valid_segmentation += 1
                            continue  # ignore this instance
                    obj["segmentation"] = segm

                keypts = anno.get("keypoints", None)
                if keypts:  # list[int]
                    # Every third value (idx % 3 == 2) is kept untouched; the
                    # other two values of each triplet are coordinates, which
                    # are assumed to be "pixel indices" and get 0.5 added to
                    # become floating point coordinates. A new list is built
                    # so the record held by the COCO API is not mutated.
                    obj["keypoints"] = [
                        v if idx % 3 == 2 else v + 0.5
                        for idx, v in enumerate(keypts)
                    ]

                obj["bbox_mode"] = BoxMode.XYWH_ABS
                if id_map:
                    obj["category_id"] = id_map[obj["category_id"]]
                objs.append(obj)

            record["annotations"] = objs
            dataset_dicts.append(record)

        if num_instances_without_valid_bbox > 0:
            logger.warning(
                "Filtered out {} instances without valid bbox. "
                "There might be issues in your dataset generation process.".
                format(num_instances_without_valid_bbox))

        if num_instances_without_valid_segmentation > 0:
            logger.warning(
                "Filtered out {} instances without valid segmentation. "
                "There might be issues in your dataset generation process.".
                format(num_instances_without_valid_segmentation))

        return dataset_dicts

    def _get_metadata(self):
        """Build the metadata dict for this dataset.

        Returns:
            dict: "thing_classes" (``["person"]``), "evaluator_type" (taken
            from ``_PREDEFINED_SPLITS_CITYPERSONS``), "image_root" and
            "json_file".
        """
        meta = {"thing_classes": ["person"]}
        meta["evaluator_type"] = _PREDEFINED_SPLITS_CITYPERSONS["evaluator_type"]["citypersons"]
        meta["image_root"] = self.image_root
        meta["json_file"] = self.json_file
        return meta

    def evaluate(self, predictions):
        """Dataset must provide an evaluation function to evaluate model.

        Not implemented for this dataset; always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @property
    def ground_truth_annotations(self):
        """Return the loaded per-image records (``self.dataset_dicts``)."""
        return self.dataset_dicts