# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch
import torch.distributed as dist
from torch import nn
from torch.autograd.function import Function
from torch.nn import functional as F
from cvpods.utils import comm
from .wrappers import BatchNorm1d, BatchNorm2d
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward folds the four buffers into a per-channel scale and bias and
    computes `x * scale + bias`, where
    `scale = weight / sqrt(running_var + eps)` and
    `bias = bias - running_mean * scale`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): number of channels C of the input.
            eps (float): value added to ``running_var`` before the inverse sqrt.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialized to 1 - eps so that running_var + eps == 1, i.e. identity.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """
        Apply the frozen normalization to ``x``.

        The per-channel scale/bias are reshaped to (1, C, 1, 1) and broadcast
        against ``x``, so ``x`` is expected to be laid out as (N, C, H, W).
        The same elementwise affine form is used whether or not ``x`` requires
        grad; the buffers are not trainable, so gradients only flow to ``x``.
        """
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Upgrade checkpoints written by older versions of this module before
        delegating to ``nn.Module._load_from_state_dict``.

        - No version metadata: missing running stats are filled with this
          module's current buffers.
        - version < 2: missing running stats are filled with identity values
          (mean 0, var 1) to silence missing-key warnings.
        - version < 3: stored ``running_var`` did not include ``eps``, so
          ``eps`` is subtracted from it.
        """
        version = local_metadata.get("version", None)
        mean_key = prefix + "running_mean"
        var_key = prefix + "running_var"

        if version is None:
            # keep the origin key if version is None
            if mean_key not in state_dict:
                state_dict[mean_key] = self.running_mean.clone().detach()
            if var_key not in state_dict:
                state_dict[var_key] = self.running_var.clone().detach()
        else:
            if version < 2:
                # No running_mean/var in early versions.
                # Filling them in silences the missing-key warnings.
                if mean_key not in state_dict:
                    state_dict[mean_key] = torch.zeros_like(self.running_mean)
                if var_key not in state_dict:
                    state_dict[var_key] = torch.ones_like(self.running_var)
            if version < 3 and var_key in state_dict:
                logger = logging.getLogger(__name__)
                logger.info(
                    "FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))
                )
                # In version < 3, running_var are used without +eps.
                # Subtract out-of-place: an in-place `-=` would mutate the
                # caller's checkpoint tensor (and double-apply on reloads).
                state_dict[var_key] = state_dict[var_key] - self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module): a module that is, or contains,
                ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` instances
                (subclasses included).

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Note:
            ``weight``/``bias`` are cloned, but ``running_mean``/``running_var``
            share storage with the source module.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # Recurse into children, replacing converted ones in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable or None): one of
            "BN", "SyncBN", "SyncBN1d", "FrozenBN", "GN" (32 groups) or "nnSyncBN";
            or a callable that takes the channel count and returns a module;
            or an empty string / None to request no normalization.
        out_channels (int): number of channels passed to the layer constructor.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a non-empty string not listed above.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            return None
        norm = {
            "BN": BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "SyncBN1d": NaiveSyncBatchNorm1d,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm(out_channels)
def get_activation(activation):
    """
    Build an activation layer from a config node.

    Args:
        activation: None, or an object exposing ``NAME`` (one of "ReLU",
            "ReLU6") and ``INPLACE`` (passed as the layer's ``inplace`` flag).

    Returns:
        nn.Module or None: the activation layer, or None if ``activation`` is None.

    Raises:
        KeyError: if ``activation.NAME`` is not a supported activation name.
    """
    if activation is None:
        return None
    atype = activation.NAME
    inplace = activation.INPLACE
    act = {
        "ReLU": nn.ReLU,
        "ReLU6": nn.ReLU6,
    }[atype]
    return act(inplace=inplace)
class AllReduce(Function):
    """
    Differentiable sum of a tensor over every rank of the default process group.

    forward: gathers one copy of ``input`` from each rank and sums them locally.
    backward: sums the incoming gradient across ranks with ``all_reduce``.
    """

    @staticmethod
    def forward(ctx, input):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(input) for _ in range(world_size)]
        # all_gather + a local sum is used instead of an in-place all_reduce.
        dist.all_gather(gathered, input, async_op=False)
        return torch.stack(gathered, dim=0).sum(dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # The reduction happens in place on grad_output, which is then returned.
        dist.all_reduce(grad_output, async_op=False)
        return grad_output
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    BatchNorm2d whose training-time statistics are averaged across processes.

    This is a replacement for `torch.nn.SyncBatchNorm`, which has been observed
    to give significantly worse AP (and sometimes NaN) when per-worker batch
    sizes differ a lot (e.g. with scale augmentation, or in the mask head).
    It is slower than `nn.SyncBatchNorm`.

    With a single process, or in eval mode, it defers to the parent BatchNorm2d.
    Otherwise the per-channel mean and mean-of-squares are averaged over all
    ranks, and the (biased) variance ``E[x^2] - E[x]^2`` is used both for
    normalization and for the running-statistics update.
    """

    def forward(self, input):
        synchronize = comm.get_world_size() != 1 and self.training
        if not synchronize:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        num_channels = input.shape[1]
        reduce_dims = [0, 2, 3]

        # Pack both moments into one vector so a single collective suffices.
        local_moments = torch.cat(
            [
                torch.mean(input, dim=reduce_dims),
                torch.mean(input * input, dim=reduce_dims),
            ],
            dim=0,
        )
        global_moments = AllReduce.apply(local_moments) * (1.0 / dist.get_world_size())
        mean, meansqr = torch.split(global_moments, num_channels)
        var = meansqr - mean * mean

        momentum = self.momentum
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)

        scale = self.weight * torch.rsqrt(var + self.eps)
        shift = self.bias - mean * scale
        return input * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
class NaiveSyncBatchNorm1d(BatchNorm1d):
    """
    BatchNorm1d whose training-time statistics are averaged across processes.

    `torch.nn.SyncBatchNorm` has known unknown bugs.
    It produces significantly worse AP (and sometimes goes NaN)
    when the batch size on each worker is quite different
    (e.g., when scale augmentation is used, or when it is applied to mask head).
    Use this implementation before `nn.SyncBatchNorm` is fixed.
    It is slower than `nn.SyncBatchNorm`.

    Accepts (N, C) or (N, C, L) input: statistics are reduced over every
    dimension except the channel dimension 1.
    """

    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
        # [0] for (N, C) input, [0, 2] for (N, C, L) input. Previously only
        # dim 0 was reduced, which broke the running-stat update for 3D input.
        reduce_dims = [0] + list(range(2, input.dim()))
        mean = torch.mean(input, dim=reduce_dims)
        meansqr = torch.mean(input * input, dim=reduce_dims)
        # Pack both moments into one vector so a single collective suffices.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
        mean, meansqr = torch.split(vec, C)
        # Biased variance E[x^2] - E[x]^2, also used for the running stats.
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        # Broadcast per-channel params along dim 1: (1, C) or (1, C, 1).
        shape = (1, -1) + (1,) * (input.dim() - 2)
        scale = scale.reshape(shape)
        bias = bias.reshape(shape)
        return input * scale + bias