Source code for cvpods.layers.deform_conv_with_off
#!/usr/bin/python3
# -*- coding:utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
from .deform_conv import DeformConv, ModulatedDeformConv
class DeformConvWithOff(nn.Module):
    """Deformable convolution that predicts its own sampling offsets.

    A plain ``nn.Conv2d`` (``offset_conv``) is applied to the input to produce
    ``deformable_groups * 2 * kernel_size * kernel_size`` offset channels,
    which are then passed together with the input to ``DeformConv``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the deformable conv.
        kernel_size: kernel size shared by the offset conv and the deformable
            conv. It is squared to size the offset channels, so it must be an
            ``int``.
        stride: stride shared by both convolutions.
        padding: padding shared by both convolutions.
        dilation: dilation shared by both convolutions.
        deformable_groups: number of deformable groups; scales the number of
            predicted offset channels.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1,
                 dilation=1, deformable_groups=1):
        super().__init__()
        # Two offset channels per kernel tap and per deformable group.
        offset_channels = deformable_groups * 2 * kernel_size * kernel_size
        # The offset conv must use the same geometry (including dilation) as
        # the deformable conv: otherwise, for dilation != 1, the predicted
        # offset map has a different spatial size than the deformable conv
        # output. Dilation does not change the weight shape, so existing
        # checkpoints remain loadable.
        self.offset_conv = nn.Conv2d(
            in_channels,
            offset_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.dcn = DeformConv(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            deformable_groups=deformable_groups,
        )

    def forward(self, input):
        """Predict offsets from ``input`` and apply the deformable conv.

        Args:
            input: feature map fed to both the offset conv and the deformable
                conv (presumably ``(N, C, H, W)``; confirm against callers).

        Returns:
            The result of ``self.dcn(input, offset)``.
        """
        offset = self.offset_conv(input)
        return self.dcn(input, offset)
class ModulatedDeformConvWithOff(nn.Module):
    """Modulated deformable convolution that predicts its offsets and mask.

    A plain ``nn.Conv2d`` (``offset_mask_conv``) is applied to the input to
    produce ``deformable_groups * 3 * kernel_size * kernel_size`` channels.
    These are split into three equal chunks along the channel axis: the first
    two are concatenated to form the offsets and the third, after a sigmoid,
    is the modulation mask. Both are passed with the input to
    ``ModulatedDeformConv``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the deformable conv.
        kernel_size: kernel size shared by the offset/mask conv and the
            deformable conv. It is squared to size the predicted channels, so
            it must be an ``int``.
        stride: stride shared by both convolutions.
        padding: padding shared by both convolutions.
        dilation: dilation shared by both convolutions.
        deformable_groups: number of deformable groups; scales the number of
            predicted offset/mask channels.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1,
                 dilation=1, deformable_groups=1):
        super().__init__()
        # Three channels per kernel tap and per deformable group:
        # two for the offsets and one for the mask.
        offset_mask_channels = (
            deformable_groups * 3 * kernel_size * kernel_size
        )
        # The offset/mask conv must use the same geometry (including
        # dilation) as the deformable conv: otherwise, for dilation != 1, the
        # predicted maps have a different spatial size than the deformable
        # conv output. Dilation does not change the weight shape, so existing
        # checkpoints remain loadable.
        self.offset_mask_conv = nn.Conv2d(
            in_channels,
            offset_mask_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.dcnv2 = ModulatedDeformConv(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            deformable_groups=deformable_groups,
        )

    def forward(self, input):
        """Predict offsets and mask from ``input`` and apply the deformable conv.

        Args:
            input: feature map fed to both the offset/mask conv and the
                deformable conv (presumably ``(N, C, H, W)``; confirm against
                callers).

        Returns:
            The result of ``self.dcnv2(input, offset, mask)``.
        """
        x = self.offset_mask_conv(input)
        # Split channels into thirds: two offset chunks and one mask chunk.
        o1, o2, mask = torch.chunk(x, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # Squash the raw mask logits into (0, 1).
        mask = torch.sigmoid(mask)
        return self.dcnv2(input, offset, mask)