Source code for cvpods.configs.config_helper

#!/usr/bin/python3
# -*- coding:utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

import logging
import re
import time
import warnings
from ast import literal_eval
from colorama import Back, Fore, Style
from easydict import EasyDict


def highlight(keyword, target, color=Fore.BLACK + Back.YELLOW):
    """
    Surround every match of ``keyword`` inside ``target`` with color codes.

    Args:
        keyword(str): text to highlight. It is handed to ``re.sub`` as the
            pattern, so it is interpreted as a regular expression.
        target(str): string in which matches are searched.
        color(str): sequence written in front of each match; black
            foreground on yellow background by default.

    Returns:
        (str) ``target`` where each match is preceded by ``color`` and
        followed by ``Style.RESET_ALL``.
    """
    # "\g<0>" re-inserts the whole matched text between the two color codes.
    painted_match = color + r"\g<0>" + Style.RESET_ALL
    return re.sub(keyword, painted_match, target)
def find_key(param_dict: dict, key: str) -> dict:
    """
    Collect the entries of a (possibly nested) dict whose keys match ``key``.

    ``key`` is used as a regular expression with ``re.search``, so a partial
    match is enough. Values that are dicts are searched recursively; when a
    nested search finds something, the filtered sub-dict is stored under the
    parent key (replacing the full value if the parent key matched as well).

    Args:
        param_dict(dict): dict to search.
        key(str): regular expression matched against each key.

    Returns:
        (dict) matching entries, keeping the nesting of ``param_dict``.

    Examples::

        >>> d = dict(abc=2, ab=4, c=4)
        >>> find_key(d, "ab")
        {'abc': 2, 'ab': 4}
    """
    matches = {}
    for name, value in param_dict.items():
        if re.search(key, name):
            matches[name] = value
        if not isinstance(value, dict):
            continue
        nested_matches = find_key(value, key)
        if nested_matches:
            # A non-empty nested result takes precedence over the raw value.
            matches[name] = nested_matches
    return matches
def diff_dict(src, dst):
    """
    Find the entries of ``src`` that are missing from, or different in, ``dst``.

    Only keys of ``src`` are inspected; keys that exist solely in ``dst`` are
    ignored. When both sides hold a dict under the same key, the comparison
    recurses so that only the differing leaves are reported.

    Args:
        src(dict): src dict
        dst(dict): dst dict

    Returns:
        (dict) dict that maps every differing key to its value in ``src``
        (or to the nested difference when both values are dicts).
    """
    diff_result = {}
    for k, v in src.items():
        if k not in dst:
            diff_result[k] = v
            continue
        other = dst[k]
        if other == v:
            continue
        if isinstance(v, dict) and isinstance(other, dict):
            diff_result[k] = diff_dict(v, other)
        else:
            # Recursing is only meaningful when both sides are dicts; a dict
            # on one side and a scalar on the other is reported as a whole.
            diff_result[k] = v
    return diff_result
def version_update(config):
    """
    Backward compatibility of old config's Augmentation pipelines and
    Optimizer configs; convert old format into new ones.

    The conversion is done in place on ``config.INPUT`` and ``config.SOLVER``
    and nothing is returned. Each detected group of old-format keys emits a
    ``DeprecationWarning`` followed by a one second pause.

    * ``INPUT``: the old ``MIN_SIZE_*`` / ``MAX_SIZE_*`` keys are copied into
      the arguments of the first ``ResizeShortestEdge`` entry of the matching
      pipeline (if there is one) and then removed from ``INPUT``.
    * ``SOLVER``: the old flat scheduler / optimizer keys are moved into
      ``SOLVER["LR_SCHEDULER"]`` and ``SOLVER["OPTIMIZER"]``. When either of
      those sub-dicts is missing, both are rebuilt from the old keys with
      built-in fallback values. The old keys are removed afterwards.

    Args:
        config: object exposing dict-like ``INPUT`` and ``SOLVER`` attributes.
            ``INPUT["AUG"]`` must provide ``TRAIN_PIPELINES`` and
            ``TEST_PIPELINES``, whose entries are indexable as
            ``(transform_name, transform_args)``.
    """
    # Old Augmentation config
    input_cfg = config.INPUT
    train_input_pop_list = [
        ("MIN_SIZE_TRAIN", "short_edge_length"),
        ("MIN_SIZE_TRAIN_SAMPLING", "sample_style"),
        ("MAX_SIZE_TRAIN", "max_size"),
    ]
    test_input_pop_list = [
        ("MIN_SIZE_TEST", "short_edge_length"),
        ("MAX_SIZE_TEST", "max_size"),
    ]

    train_contains = [k for k in train_input_pop_list if k[0] in input_cfg]
    test_contains = [k for k in test_input_pop_list if k[0] in input_cfg]

    train_transforms = [t[0] for t in input_cfg["AUG"]["TRAIN_PIPELINES"]]
    test_transforms = [t[0] for t in input_cfg["AUG"]["TEST_PIPELINES"]]

    warnings.filterwarnings(
        "default", category=DeprecationWarning, module=__name__
    )  # only change local warning level

    if train_contains:
        warnings.warn("Old format training config will be deprecated ", DeprecationWarning)
        time.sleep(1)
        if "ResizeShortestEdge" in train_transforms:
            idx = train_transforms.index("ResizeShortestEdge")
            resize_args = input_cfg["AUG"]["TRAIN_PIPELINES"][idx][1]
            for old_key, new_key in train_contains:
                resize_args[new_key] = input_cfg[old_key]

    if test_contains:
        warnings.warn("Old format testing config will be deprecated ", DeprecationWarning)
        time.sleep(1)
        if "ResizeShortestEdge" in test_transforms:
            idx = test_transforms.index("ResizeShortestEdge")
            resize_args = input_cfg["AUG"]["TEST_PIPELINES"][idx][1]
            for old_key, new_key in test_contains:
                resize_args[new_key] = input_cfg[old_key]

    # The old keys are dropped even when no ResizeShortestEdge entry exists.
    for old_key, _ in train_contains + test_contains:
        config.INPUT.pop(old_key)

    # Old SOLVER format
    solver_cfg = config.SOLVER
    # Snapshot the keys: solver_cfg is modified while they are walked.
    candidates = list(solver_cfg.keys())
    lr_pop_list = [
        "LR_SCHEDULER_NAME", "MAX_ITER", "STEPS",
        "WARMUP_FACTOR", "WARMUP_ITERS", "WARMUP_METHOD", "GAMMA",
    ]
    optim_pop_list = [
        "BASE_LR", "BIAS_LR_FACTOR", "WEIGHT_DECAY",
        "WEIGHT_DECAY_NORM", "WEIGHT_DECAY_BIAS", "MOMENTUM",
    ]
    old_solver_keys = lr_pop_list + optim_pop_list

    if any(item in candidates for item in old_solver_keys):
        warnings.warn("Old format solver config will be deprecated ", DeprecationWarning)
        time.sleep(1)
        if "LR_SCHEDULER" in solver_cfg and "OPTIMIZER" in solver_cfg:
            for k in candidates:
                if k in lr_pop_list:
                    # The scheduler name lives under "NAME" in the new
                    # format, exactly as in the rebuild branch below.
                    new_key = "NAME" if k == "LR_SCHEDULER_NAME" else k
                    solver_cfg["LR_SCHEDULER"][new_key] = solver_cfg[k]
                elif k in optim_pop_list:
                    solver_cfg["OPTIMIZER"][k] = solver_cfg[k]
        else:
            solver_cfg["LR_SCHEDULER"] = dict(
                NAME=solver_cfg.get("LR_SCHEDULER_NAME", "WarmupMultiStepLR"),
                MAX_ITER=solver_cfg.get("MAX_ITER", 30000),
                STEPS=solver_cfg.get("STEPS", (20000,)),
                WARMUP_FACTOR=solver_cfg.get("WARMUP_FACTOR", 1.0 / 1000),
                WARMUP_ITERS=solver_cfg.get("WARMUP_ITERS", 1000),
                WARMUP_METHOD=solver_cfg.get("WARMUP_METHOD", "linear"),
                GAMMA=solver_cfg.get("GAMMA", 0.1),
            )
            solver_cfg["OPTIMIZER"] = dict(
                NAME="SGD",
                BASE_LR=solver_cfg.get("BASE_LR", 0.001),
                BIAS_LR_FACTOR=solver_cfg.get("BIAS_LR_FACTOR", 1.0),
                WEIGHT_DECAY=solver_cfg.get("WEIGHT_DECAY", 0.0001),
                WEIGHT_DECAY_NORM=solver_cfg.get("WEIGHT_DECAY_NORM", 0.0),
                WEIGHT_DECAY_BIAS=solver_cfg.get("WEIGHT_DECAY_BIAS", 0.0001),
                MOMENTUM=solver_cfg.get("MOMENTUM", 0.9),
            )

        for k in candidates:
            if k in old_solver_keys:
                solver_cfg.pop(k)
def _assert_with_logging(cond, msg): logger = logging.getLogger(__name__) if not cond: logger.error(msg) assert cond, msg def _decode_cfg_value(value): """ Decodes a raw config value (e.g., from a yaml config files or command line argument) into a Python object. If the value is a dict, it will be interpreted as a new config dict. If the value is a str, it will be evaluated as literals. Otherwise it is returned as-is. Args: value (dict or str): value to be decoded """ if isinstance(value, str): # Try to interpret `value` as a string, number, tuple, list, dict, boolean, or None try: value = literal_eval(value) # The following two excepts allow v to pass through when it represents a # string. # # Longer explanation: # The type of v is always a string (before calling literal_eval), but # sometimes it *represents* a string and other times a data structure, like # a list. In the case that v represents a string, what we got back from the # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is # ok with '"foo"', but will raise a ValueError if given 'foo'. In other # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval # will raise a SyntaxError. except (ValueError, SyntaxError): pass if isinstance(value, dict): return EasyDict(value) else: return value def _cast_cfg_value_type(replacement, original, full_key): """ Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if it matches exactly or is one of a few cases in which the type can be easily coerced. 
""" logger = logging.getLogger(__name__) ori_type = type(original) new_type = type(replacement) if original is None or replacement is None: logger.info("None type, {} to {}".format(ori_type, new_type)) return replacement # The types must match (with some exceptions) if new_type == ori_type: logger.info( "change value of '{}' from {} to {}".format(full_key, original, replacement) ) return replacement # try to casts replacement to original type try: replacement = ori_type(replacement) return replacement except Exception: logger.error( "Could not cast '{}' from {} to {} with values ({} vs. {})".format( full_key, new_type, ori_type, replacement, original) ) raise ValueError