summaryrefslogtreecommitdiffstats
path: root/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'r_basicsr/ops/dcn/src/deform_conv_cuda.cpp')
-rw-r--r--r_basicsr/ops/dcn/src/deform_conv_cuda.cpp685
1 files changed, 685 insertions, 0 deletions
diff --git a/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp b/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000..191298a
--- /dev/null
+++ b/r_basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide int group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}