Source code for cvpods.layers.deform_conv_with_off

#!/usr/bin/python3
# -*- coding:utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn

from .deform_conv import DeformConv, ModulatedDeformConv


class DeformConvWithOff(nn.Module):
    """Deformable convolution that predicts its own sampling offsets.

    A regular ``nn.Conv2d`` (``offset_conv``) computes an offset map from the
    input, and ``DeformConv`` then convolves the input using those offsets.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels of the deformable conv.
        kernel_size (int): kernel size shared by the offset conv and the
            deformable conv.
        stride (int): stride shared by both convolutions.
        padding (int): zero padding shared by both convolutions.
        dilation (int): dilation shared by both convolutions.
        deformable_groups (int): number of offset groups; the offset conv
            emits ``deformable_groups * 2 * kernel_size * kernel_size``
            channels (presumably a 2-D displacement per kernel tap per group).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1,
                 dilation=1, deformable_groups=1):
        super(DeformConvWithOff, self).__init__()
        # Deformable convolution expects one offset vector per *output*
        # location, so the offset map must have the same spatial size as
        # ``self.dcn``'s output. Both convs therefore use identical geometry
        # (kernel_size / stride / padding / dilation); leaving ``dilation``
        # out here made the sizes disagree whenever dilation > 1.
        self.offset_conv = nn.Conv2d(
            in_channels,
            deformable_groups * 2 * kernel_size * kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.dcn = DeformConv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            deformable_groups=deformable_groups,
        )

    def forward(self, input):
        """Predict offsets from ``input`` and apply the deformable conv.

        Args:
            input (Tensor): feature map, presumably (N, in_channels, H, W).

        Returns:
            Tensor: result of ``self.dcn(input, offset)``.
        """
        offset = self.offset_conv(input)
        output = self.dcn(input, offset)
        return output
class ModulatedDeformConvWithOff(nn.Module):
    """Modulated deformable convolution that predicts offsets and a mask.

    A regular ``nn.Conv2d`` (``offset_mask_conv``) produces, from the input,
    ``deformable_groups * 3 * kernel_size * kernel_size`` channels. They are
    split into an offset map (first two thirds) and a modulation mask (last
    third, squashed to (0, 1) with a sigmoid). Both are then passed with the
    input to ``ModulatedDeformConv``.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels of the deformable conv.
        kernel_size (int): kernel size shared by the offset/mask conv and the
            deformable conv.
        stride (int): stride shared by both convolutions.
        padding (int): zero padding shared by both convolutions.
        dilation (int): dilation shared by both convolutions.
        deformable_groups (int): number of offset groups.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1,
                 dilation=1, deformable_groups=1):
        super(ModulatedDeformConvWithOff, self).__init__()
        # The offset and mask maps must match ``self.dcnv2``'s output spatial
        # size (one offset/mask value per output location), so this conv
        # uses exactly the same kernel_size / stride / padding / dilation.
        # Without ``dilation`` the sizes disagreed whenever dilation > 1.
        self.offset_mask_conv = nn.Conv2d(
            in_channels,
            deformable_groups * 3 * kernel_size * kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.dcnv2 = ModulatedDeformConv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            deformable_groups=deformable_groups,
        )

    def forward(self, input):
        """Predict offsets and mask from ``input`` and apply the modulated conv.

        Args:
            input (Tensor): feature map, presumably (N, in_channels, H, W).

        Returns:
            Tensor: result of ``self.dcnv2(input, offset, mask)``.
        """
        x = self.offset_mask_conv(input)
        # Three equal channel chunks of deformable_groups * k * k each:
        # the first two form the offset map, the last is the mask.
        o1, o2, mask = torch.chunk(x, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # Sigmoid keeps the modulation weights in (0, 1).
        mask = torch.sigmoid(mask)
        output = self.dcnv2(input, offset, mask)
        return output