deep-tempest/end-to-end/models/op/deform_attn.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
import os
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from einops.layers.torch import Rearrange
from distutils.version import LooseVersion
from torch.utils.cpp_extension import load

# JIT-compile the CUDA extension on first import; the CUDA bindings file is chosen
# to match the installed PyTorch version (>= 1.10 uses the pt110 sources).
module_path = os.path.dirname(__file__)
deform_attn_ext = load(
    'deform_attn',
    sources=[
        os.path.join(module_path, 'deform_attn_ext.cpp'),
        os.path.join(module_path,
                     'deform_attn_cuda_pt110.cpp' if LooseVersion(torch.__version__) >= LooseVersion('1.10.0')
                     else 'deform_attn_cuda_pt109.cpp'),
        os.path.join(module_path, 'deform_attn_cuda_kernel.cu'),
    ],
)


class Mlp(nn.Module):
    """Multilayer perceptron.

    Args:
        x: (B, D, H, W, C)

    Returns:
        x: (B, D, H, W, C)
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class DeformAttnFunction(Function):
    """Autograd wrapper around the compiled deformable attention CUDA kernels."""

    @staticmethod
    def forward(ctx,
                q,
                kv,
                offset,
                kernel_h,
                kernel_w,
                stride=1,
                padding=0,
                dilation=1,
                attention_heads=1,
                deformable_groups=1,
                clip_size=1):
        ctx.kernel_h = kernel_h
        ctx.kernel_w = kernel_w
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.attention_heads = attention_heads
        ctx.deformable_groups = deformable_groups
        ctx.clip_size = clip_size
        if q.requires_grad or kv.requires_grad or offset.requires_grad:
            ctx.save_for_backward(q, kv, offset)
        output = q.new_empty(q.shape)
        # Workspace buffers filled by the CUDA kernel and reused in backward.
        ctx._bufs = [q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0)]
        deform_attn_ext.deform_attn_forward(q, kv, offset, output,
                                            ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx.kernel_h, ctx.kernel_w,
                                            ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
                                            ctx.dilation, ctx.attention_heads, ctx.deformable_groups, ctx.clip_size)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # Only a CUDA implementation of the backward pass is provided.
        if not grad_output.is_cuda:
            raise NotImplementedError
        q, kv, offset = ctx.saved_tensors
        grad_q = torch.zeros_like(q)
        grad_kv = torch.zeros_like(kv)
        grad_offset = torch.zeros_like(offset)
        deform_attn_ext.deform_attn_backward(q, kv, offset, ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx._bufs[3],
                                             ctx._bufs[4], grad_q, grad_kv, grad_offset,
                                             grad_output, ctx.kernel_h, ctx.kernel_w, ctx.stride,
                                             ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                             ctx.attention_heads, ctx.deformable_groups, ctx.clip_size)
        # One gradient per forward input; the non-tensor hyper-parameters get None.
        return (grad_q, grad_kv, grad_offset, None, None, None, None, None, None, None, None)


deform_attn = DeformAttnFunction.apply
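
# Shape conventions (an inference from the channel arithmetic in DeformAttnPack below,
# not documentation from the original file):
#   q:      (n, 1, c, h, w)              projected queries for the current clip
#   kv:     (n, clip_size, 2 * c, h, w)  projected keys and values, concatenated on dim 2
#   offset: (n, clip_size * deformable_groups * kernel_h * kernel_w * 2, h, w)
# All tensors must live on a CUDA device, since only a CUDA kernel is provided.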


class DeformAttn(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 attention_window=[3, 3],
                 deformable_groups=12,
                 attention_heads=12,
                 clip_size=1):
        super(DeformAttn, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_h = attention_window[0]
        self.kernel_w = attention_window[1]
        self.attn_size = self.kernel_h * self.kernel_w
        self.deformable_groups = deformable_groups
        self.attention_heads = attention_heads
        self.clip_size = clip_size
        self.stride = 1
        self.padding = self.kernel_h // 2
        self.dilation = 1
        # Per-pixel linear projections for queries, keys and values, applied along the
        # channel dimension of (n, d, c, h, w) tensors.
        self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                    nn.Linear(self.in_channels, self.in_channels),
                                    Rearrange('n d h w c -> n d c h w'))
        # Channel-wise feed-forward applied after the deformable attention.
        self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'),
                                 Mlp(self.in_channels, self.in_channels * 2),
                                 Rearrange('n d h w c -> n d c h w'))

    def forward(self, q, k, v, offset):
        q = self.proj_q(q)
        kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2)
        v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation,
                        self.attention_heads, self.deformable_groups, self.clip_size)
        v = v + self.mlp(v)
        return v


class DeformAttnPack(DeformAttn):
    """A Deformable Attention Encapsulation that acts as normal attention layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        attention_window (int or tuple[int]): Attention window size. Default: [3, 3].
        attention_heads (int): Attention head number. Default: 12.
        deformable_groups (int): Deformable offset groups. Default: 12.
        clip_size (int): Clip size. Default: 1.
    """

    def __init__(self, *args, **kwargs):
        super(DeformAttnPack, self).__init__(*args, **kwargs)
        # Predicts the sampling offsets from the concatenated query and key features.
        self.conv_offset = nn.Conv2d(
            self.in_channels * (1 + self.clip_size),
            self.clip_size * self.deformable_groups * self.attn_size * 2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            dilation=(1, 1),
            bias=True)
        self.init_weight()

    def init_weight(self):
        # Zero-initialize the offset branch so training starts from regular (undeformed) sampling.
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, q, k, v):
        out = self.conv_offset(torch.cat([q.flatten(1, 2), k.flatten(1, 2)], 1))
        o1, o2 = torch.chunk(out, 2, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        q = self.proj_q(q)
        kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2)
        v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation,
                        self.attention_heads, self.deformable_groups, self.clip_size)
        v = v + self.mlp(v)
        return v
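

# Minimal usage sketch (not part of the original module): the tensor shapes below are an
# assumption inferred from the channel arithmetic above, namely a single query clip of
# shape (n, 1, c, h, w) and `clip_size` support frames of shape (n, clip_size, c, h, w)
# for k and v. Running it requires a CUDA device and a successful JIT build of the
# extension at import time.
if __name__ == '__main__':
    n, c, h, w = 1, 144, 64, 64
    clip_size = 2
    attn = DeformAttnPack(c, c, attention_window=[3, 3], deformable_groups=12,
                          attention_heads=12, clip_size=clip_size).cuda()
    q = torch.randn(n, 1, c, h, w, device='cuda')
    k = torch.randn(n, clip_size, c, h, w, device='cuda')
    v = torch.randn(n, clip_size, c, h, w, device='cuda')
    out = attn(q, k, v)
    print(out.shape)  # should match the query shape: torch.Size([1, 1, 144, 64, 64])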