#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List, Set
import torch
from torch import optim
from cvpods.utils.registry import Registry
from .lars_sgd import LARS_SGD
OPTIMIZER_BUILDER = Registry("Optimizer builder")
NORM_MODULE_TYPES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch.nn.GroupNorm,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.LocalResponseNorm,
)
def exclude_from_wd(named_params, weight_decay, skip_list=['bias', 'bn']):
params = []
excluded_params = []
for name, param in named_params:
if not param.requires_grad:
continue
elif any(layer_name in name for layer_name in skip_list):
excluded_params.append(param)
else:
params.append(param)
return [
{'params': params, 'weight_decay': weight_decay},
{'params': excluded_params, 'weight_decay': 0., 'lars_exclude': True},
]
[docs]@OPTIMIZER_BUILDER.register()
class OptimizerBuilder:
[docs] @staticmethod
def build(model, cfg):
raise NotImplementedError
[docs]@OPTIMIZER_BUILDER.register()
class SGDBuilder(OptimizerBuilder):
[docs] @staticmethod
def build(model, cfg):
optimizer = optim.SGD(
model.parameters(),
lr=cfg.SOLVER.OPTIMIZER.BASE_LR,
weight_decay=cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY,
momentum=cfg.SOLVER.OPTIMIZER.MOMENTUM,
)
return optimizer
@OPTIMIZER_BUILDER.register()
class D2SGDBuilder(OptimizerBuilder):
@staticmethod
def build(model, cfg):
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for module in model.modules():
for key, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
lr = cfg.SOLVER.OPTIMIZER.BASE_LR
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY
if isinstance(module, NORM_MODULE_TYPES):
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY_NORM
elif key == "bias":
# NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
# and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
# hyperparameters are by default exactly the same as for regular
# weights.
lr = cfg.SOLVER.OPTIMIZER.BASE_LR * cfg.SOLVER.OPTIMIZER.BIAS_LR_FACTOR
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
optimizer = optim.SGD(
params,
cfg.SOLVER.OPTIMIZER.BASE_LR,
momentum=cfg.SOLVER.OPTIMIZER.MOMENTUM
)
return optimizer
@OPTIMIZER_BUILDER.register()
class LARS_SGDBuilder(OptimizerBuilder):
@staticmethod
def build(model, cfg):
exclude = cfg.SOLVER.OPTIMIZER.get("WD_EXCLUDE_BN_BIAS", False)
if exclude:
param = exclude_from_wd(
model.named_parameters(), cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY
)
else:
param = model.parameters()
optimizer = LARS_SGD(
param,
lr=cfg.SOLVER.OPTIMIZER.BASE_LR,
momentum=cfg.SOLVER.OPTIMIZER.MOMENTUM,
weight_decay=cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY,
nesterov=cfg.SOLVER.OPTIMIZER.get("NESTERROV", False),
eta=cfg.SOLVER.OPTIMIZER.TRUST_COEF,
eps=cfg.SOLVER.OPTIMIZER.EPS,
)
return optimizer
[docs]@OPTIMIZER_BUILDER.register()
class AdamBuilder(OptimizerBuilder):
[docs] @staticmethod
def build(model, cfg):
lr = cfg.SOLVER.OPTIMIZER.BASE_LR
optimizer = optim.Adam(
model.parameters(),
lr=lr,
weight_decay=cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY,
amsgrad=cfg.SOLVER.OPTIMIZER.AMSGRAD
)
return optimizer
[docs]@OPTIMIZER_BUILDER.register()
class AdamWBuilder(OptimizerBuilder):
[docs] @staticmethod
def build(model, cfg):
lr = cfg.SOLVER.OPTIMIZER.BASE_LR
optimizer = optim.AdamW(
model.parameters(),
lr=lr,
betas=cfg.SOLVER.OPTIMIZER.BETAS,
weight_decay=cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY,
amsgrad=cfg.SOLVER.OPTIMIZER.AMSGRAD
)
return optimizer
[docs]@OPTIMIZER_BUILDER.register()
class SGDGateLRBuilder(OptimizerBuilder):
"""
SGD Gate LR optimizer builder, used for DynamicRouting in cvpods.
This optimizer will ultiply lr for gating function.
"""
[docs] @staticmethod
def build(model, cfg):
gate_lr_multi = cfg.SOLVER.OPTIMIZER.GATE_LR_MULTI
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for name, module in model.named_modules():
for key, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
lr = cfg.SOLVER.OPTIMIZER.BASE_LR
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY
if isinstance(module, NORM_MODULE_TYPES):
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY_NORM
elif key == "bias":
# NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
# and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
# hyperparameters are by default exactly the same as for regular
# weights.
lr = cfg.SOLVER.OPTIMIZER.BASE_LR * cfg.SOLVER.OPTIMIZER.BIAS_LR_FACTOR
weight_decay = cfg.SOLVER.OPTIMIZER.WEIGHT_DECAY
if gate_lr_multi > 0.0 and "gate_conv" in name:
lr *= gate_lr_multi
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
optimizer = torch.optim.SGD(
params,
cfg.SOLVER.OPTIMIZER.BASE_LR,
momentum=cfg.SOLVER.OPTIMIZER.MOMENTUM
)
return optimizer