# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import os

import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from einops.layers.torch import Rearrange
from distutils.version import LooseVersion
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
deform_attn_ext = load(
    'deform_attn',
    sources=[
        os.path.join(module_path, 'deform_attn_ext.cpp'),
        os.path.join(module_path, 'deform_attn_cuda_pt110.cpp' if LooseVersion(torch.__version__) >= LooseVersion(
            '1.10.0') else 'deform_attn_cuda_pt109.cpp'),
        os.path.join(module_path, 'deform_attn_cuda_kernel.cu'),
    ],
)
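# The load() call above JIT-compiles the deformable-attention CUDA extension the
# first time this module is imported, selecting the binding source that matches
# the installed PyTorch version (>= 1.10 vs. older). Building it requires nvcc and
# a working CUDA toolchain; without them the import itself will fail.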


class Mlp(nn.Module):
    """Multilayer perceptron applied along the channel (last) dimension.

    Args:
        x: (B, D, H, W, C)

    Returns:
        x: (B, D, H, W, C)
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        # `drop` is currently unused.
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class DeformAttnFunction(Function):

    @staticmethod
    def forward(ctx,
                q,
                kv,
                offset,
                kernel_h,
                kernel_w,
                stride=1,
                padding=0,
                dilation=1,
                attention_heads=1,
                deformable_groups=1,
                clip_size=1):
        ctx.kernel_h = kernel_h
        ctx.kernel_w = kernel_w
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.attention_heads = attention_heads
        ctx.deformable_groups = deformable_groups
        ctx.clip_size = clip_size
        if q.requires_grad or kv.requires_grad or offset.requires_grad:
            ctx.save_for_backward(q, kv, offset)
        output = q.new_empty(q.shape)
        # Empty workspace tensors handed to the CUDA extension.
        ctx._bufs = [q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0)]
        deform_attn_ext.deform_attn_forward(q, kv, offset, output,
                                            ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx.kernel_h, ctx.kernel_w,
                                            ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
                                            ctx.dilation, ctx.attention_heads, ctx.deformable_groups, ctx.clip_size)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        q, kv, offset = ctx.saved_tensors
        grad_q = torch.zeros_like(q)
        grad_kv = torch.zeros_like(kv)
        grad_offset = torch.zeros_like(offset)
        deform_attn_ext.deform_attn_backward(q, kv, offset, ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx._bufs[3],
                                             ctx._bufs[4], grad_q, grad_kv, grad_offset, grad_output,
                                             ctx.kernel_h, ctx.kernel_w, ctx.stride, ctx.stride, ctx.padding,
                                             ctx.padding, ctx.dilation, ctx.dilation, ctx.attention_heads,
                                             ctx.deformable_groups, ctx.clip_size)

        # One gradient per forward() input; the non-tensor arguments get None.
        return (grad_q, grad_kv, grad_offset, None, None, None, None, None, None, None, None)


deform_attn = DeformAttnFunction.apply
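# Functional entry point. Mirroring DeformAttnFunction.forward, it is called as:
#   deform_attn(q, kv, offset, kernel_h, kernel_w, stride, padding, dilation,
#               attention_heads, deformable_groups, clip_size)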


class DeformAttn(nn.Module):
    """Deformable attention that consumes externally computed sampling offsets."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 attention_window=[3, 3],
                 deformable_groups=12,
                 attention_heads=12,
                 clip_size=1):
        super(DeformAttn, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_h = attention_window[0]
        self.kernel_w = attention_window[1]
        self.attn_size = self.kernel_h * self.kernel_w
        self.deformable_groups = deformable_groups
        self.attention_heads = attention_heads
        self.clip_size = clip_size
        self.stride = 1
        self.padding = self.kernel_h // 2
        self.dilation = 1

        # Channel-last linear projections for query, key and value.
        self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                 Mlp(self.in_channels, self.in_channels * 2),
                                 Rearrange('n d h w c -> n d c h w'))

    def forward(self, q, k, v, offset):
        q = self.proj_q(q)
        # Keys and values are concatenated along the channel dimension for the CUDA kernel.
        kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2)
        v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation,
                        self.attention_heads, self.deformable_groups, self.clip_size)
        v = v + self.mlp(v)
        return v


class DeformAttnPack(DeformAttn):
    """A deformable attention module that acts like a regular attention layer.

    Unlike DeformAttn, which expects precomputed sampling offsets, this wrapper
    predicts the offsets itself from the concatenated query and key features.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        attention_window (int or tuple[int]): Attention window size. Default: [3, 3].
        attention_heads (int): Attention head number. Default: 12.
        deformable_groups (int): Deformable offset groups. Default: 12.
        clip_size (int): Clip size. Default: 1.
    """

    def __init__(self, *args, **kwargs):
        super(DeformAttnPack, self).__init__(*args, **kwargs)

        # Offset head: predicts (x, y) offsets for every sampling location, i.e.
        # clip_size * deformable_groups * kernel_h * kernel_w * 2 output channels.
        self.conv_offset = nn.Conv2d(
            self.in_channels * (1 + self.clip_size),
            self.clip_size * self.deformable_groups * self.attn_size * 2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            dilation=(1, 1),
            bias=True)
        self.init_weight()

    def init_weight(self):
        # Zero-initialize the offset head so training starts from regular (non-deformed) sampling.
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, q, k, v):
        out = self.conv_offset(torch.cat([q.flatten(1, 2), k.flatten(1, 2)], 1))
        o1, o2 = torch.chunk(out, 2, dim=1)
        offset = torch.cat((o1, o2), dim=1)

        q = self.proj_q(q)
        kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2)
        v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation,
                        self.attention_heads, self.deformable_groups, self.clip_size)
        v = v + self.mlp(v)
        return v
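

if __name__ == '__main__':
    # Illustrative smoke test, not part of the original module: a minimal sketch
    # that assumes the CUDA extension above compiled successfully and a GPU is
    # available. The tensor layout is inferred from the Python wrapper rather than
    # guaranteed by the kernels: inputs are 5-D (batch, frames, channels, height,
    # width); the query carries a single frame and the key/value carry `clip_size`
    # frames, since conv_offset expects in_channels * (1 + clip_size) input channels.
    if torch.cuda.is_available():
        attn = DeformAttnPack(64, 64, attention_window=[3, 3],
                              attention_heads=4, deformable_groups=4, clip_size=2).cuda()
        q = torch.randn(2, 1, 64, 32, 32, device='cuda')  # (B, 1, C, H, W)
        k = torch.randn(2, 2, 64, 32, 32, device='cuda')  # (B, clip_size, C, H, W)
        v = torch.randn(2, 2, 64, 32, 32, device='cuda')  # (B, clip_size, C, H, W)
        out = attn(q, k, v)
        print(out.shape)  # same shape as q: (2, 1, 64, 32, 32)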