# Source code for cvpods.data.datasets.imagenetlt

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

import logging
import os
import os.path as osp
from copy import deepcopy

import numpy as np

import torch

from cvpods.utils import Timer

from ..base_dataset import BaseDataset
from ..registry import DATASETS
from .paths_route import _PREDEFINED_SPLITS_IMAGENETLT
from .imagenet_categories import IMAGENET_CATEGORIES

"""
This file contains the ImageNet-LT dataset, which parses ImageNet-LT label files into dicts in "cvpods format".
"""

logger = logging.getLogger(__name__)


@DATASETS.register()
class ImageNetLTDataset(BaseDataset):
    """ImageNet-LT classification dataset.

    Samples are read from a plain-text label file in which every non-blank
    line has the form ``<image path> <integer label>``. The image root and
    label file (relative to ``self.data_root``) are looked up in
    ``_PREDEFINED_SPLITS_IMAGENETLT["imagenetlt"]`` keyed by ``self.name``.
    NOTE(review): ``self.name`` / ``self.data_root`` are presumably set by
    ``BaseDataset.__init__`` from ``dataset_name`` / ``cfg`` -- confirm there.
    """

    def __init__(self, cfg, dataset_name, transforms=None, is_train=True):
        """
        Args:
            cfg: config node; ``cfg.TEST.get("WITH_GT", False)`` is stored as
                ``self.eval_with_gt``.
            dataset_name (str): split name, forwarded to ``BaseDataset``.
            transforms (list | None): transforms forwarded to ``BaseDataset``.
                ``None`` (the default) means "no transforms"; a fresh empty
                list is built per instance instead of sharing one mutable
                default list between all instances.
            is_train (bool): forwarded to ``BaseDataset``.
        """
        if transforms is None:
            transforms = []
        super().__init__(cfg, dataset_name, transforms, is_train)

        image_root, label_file = _PREDEFINED_SPLITS_IMAGENETLT["imagenetlt"][self.name]
        self.label_file = osp.join(self.data_root, label_file)
        self.image_root = osp.join(self.data_root, image_root)

        self.meta = self._get_metadata()
        self.dataset_dicts = self._load_annotations()
        self._set_group_flag()

        self.eval_with_gt = cfg.TEST.get("WITH_GT", False)

    def __getitem__(self, index):
        """Load one sample, apply transforms and convert the image to a tensor.

        Returns:
            dict: a deep copy of ``self.dataset_dicts[index]`` with an added
            ``"image"`` key holding a channel-first (CHW or NCHW) tensor.
            If the transforms return a dict of pipelines, a dict mapping each
            pipeline name to such a sample dict is returned instead.

        Raises:
            ValueError: if a transformed image has an unsupported layout.
        """
        dataset_dict = deepcopy(self.dataset_dicts[index])

        image = self._read_data(dataset_dict["file_name"])
        annotations = dataset_dict.get("annotations", None)

        # Annotations returned by the transforms are not stored in the sample
        # (classification labels live in "category_id"), so they are dropped.
        images, _ = self._apply_transforms(image, annotations)

        if isinstance(images, dict):
            # Multiple input pipelines: each value is an (image, annotations)
            # pair; every pipeline gets its own copy of the sample dict.
            ret = {}
            for desc, (img, _) in images.items():
                dd = deepcopy(dataset_dict)
                dd["image"] = self._to_chw_tensor(img)
                ret[desc] = dd
            return ret

        dataset_dict["image"] = self._to_chw_tensor(images)
        return dataset_dict

    @staticmethod
    def _to_chw_tensor(img):
        """Convert a CHW/HWC image or an NCHW/NHWC batch to a channel-first tensor.

        The rank is checked before the channel position. Otherwise a batch of
        three NHWC images (``shape[0] == 3``) would be mistaken for a CHW
        image and returned unconverted.

        Raises:
            ValueError: if no 3-channel CHW/HWC/NCHW/NHWC layout matches.
        """
        ndim = len(img.shape)
        if ndim == 4:
            if img.shape[-1] == 3:
                img = img.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            elif img.shape[1] != 3:
                raise ValueError(
                    "Unsupported image batch shape {}; expected NCHW or NHWC "
                    "with 3 channels.".format(tuple(img.shape)))
        elif ndim == 3:
            if img.shape[0] != 3:
                if img.shape[-1] != 3:
                    raise ValueError(
                        "Unsupported image shape {}; expected CHW or HWC "
                        "with 3 channels.".format(tuple(img.shape)))
                img = img.transpose(2, 0, 1)  # HWC -> CHW
        else:
            raise ValueError(
                "Unsupported image rank {} (shape {}); expected a 3-D image "
                "or a 4-D batch.".format(ndim, tuple(img.shape)))
        return torch.as_tensor(np.ascontiguousarray(img))

    def __len__(self):
        """Return the number of samples parsed from the label file."""
        return len(self.dataset_dicts)

    def _get_metadata(self):
        """Build dataset metadata from ``IMAGENET_CATEGORIES``.

        Returns:
            dict: ``thing_classes`` -- the second field of each category entry
            (presumably the class name), ordered by category id -- and the
            ``evaluator_type`` registered for ImageNet-LT.
        """
        # Sanity checks on the static category table.
        assert len(IMAGENET_CATEGORIES) == 1000
        cat_ids = [v[0] for v in IMAGENET_CATEGORIES.values()]
        assert min(cat_ids) == 1 and max(cat_ids) == len(cat_ids), \
            "Category ids are not in [1, #categories], as expected"

        # Ensure that the category list is sorted by id.
        categories = sorted(IMAGENET_CATEGORIES.values(), key=lambda v: v[0])
        thing_classes = [v[1] for v in categories]

        return {
            "thing_classes": thing_classes,
            "evaluator_type": _PREDEFINED_SPLITS_IMAGENETLT["evaluator_type"]["imagenetlt"],
        }

    def _load_annotations(self):
        """Parse ``self.label_file`` into a list of cvpods dataset dicts.

        Each non-blank line must read ``<image path> <integer label>``. The
        label is taken from the last whitespace-separated field, so image paths
        may contain spaces. Blank lines (e.g. a trailing newline) are skipped.

        Returns:
            list[dict]: one dict per sample with keys ``image_id`` (running
            index), ``category_id`` (the integer label) and ``file_name`` (the
            path joined onto ``self.image_root``).

        Raises:
            ValueError: on a line without a label or with a non-integer label.
        """
        timer = Timer()
        logger.info('{} data path: {}'.format(self.name, self.label_file))

        dataset_dicts = []
        with open(self.label_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                img_path, label = line.rsplit(maxsplit=1)
                dataset_dicts.append({
                    "image_id": len(dataset_dicts),
                    "category_id": int(label),
                    "file_name": osp.join(self.image_root, img_path),
                })

        logger.info("Loading {} takes {:.2f} seconds.".format(
            self.label_file, timer.seconds()))
        return dataset_dicts