Diffstat (limited to 'r_basicsr/ops/fused_act')
-rw-r--r--   r_basicsr/ops/fused_act/__init__.py                   |   3
-rw-r--r--   r_basicsr/ops/fused_act/fused_act.py                  |  95
-rw-r--r--   r_basicsr/ops/fused_act/src/fused_bias_act.cpp        |  26
-rw-r--r--   r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu  | 100
4 files changed, 224 insertions, 0 deletions
diff --git a/r_basicsr/ops/fused_act/__init__.py b/r_basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000..1f8e03b
--- /dev/null
+++ b/r_basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
 +from .fused_act import FusedLeakyReLU, fused_leaky_relu
 +
 +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/r_basicsr/ops/fused_act/fused_act.py b/r_basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000..876c959
--- /dev/null
+++ b/r_basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,95 @@
 +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
 +
 +import os
 +import torch
 +from torch import nn
 +from torch.autograd import Function
 +
 +BASICSR_JIT = os.getenv('BASICSR_JIT')
 +if BASICSR_JIT == 'True':
 +    from torch.utils.cpp_extension import load
 +    module_path = os.path.dirname(__file__)
 +    fused_act_ext = load(
 +        'fused',
 +        sources=[
 +            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
 +            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
 +        ],
 +    )
 +else:
 +    try:
 +        from . import fused_act_ext
 +    except ImportError:
 +        pass
 +        # avoid annoying print output
 +        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
 +        #       '1. compile with BASICSR_EXT=True. or\n '
 +        #       '2. set BASICSR_JIT=True during running')
 +
 +
 +class FusedLeakyReLUFunctionBackward(Function):
 +
 +    @staticmethod
 +    def forward(ctx, grad_output, out, negative_slope, scale):
 +        ctx.save_for_backward(out)
 +        ctx.negative_slope = negative_slope
 +        ctx.scale = scale
 +
 +        empty = grad_output.new_empty(0)
 +
 +        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
 +
 +        dim = [0]
 +
 +        if grad_input.ndim > 2:
 +            dim += list(range(2, grad_input.ndim))
 +
 +        grad_bias = grad_input.sum(dim).detach()
 +
 +        return grad_input, grad_bias
 +
 +    @staticmethod
 +    def backward(ctx, gradgrad_input, gradgrad_bias):
 +        out, = ctx.saved_tensors
 +        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
 +                                                    ctx.scale)
 +
 +        return gradgrad_out, None, None, None
 +
 +
 +class FusedLeakyReLUFunction(Function):
 +
 +    @staticmethod
 +    def forward(ctx, input, bias, negative_slope, scale):
 +        empty = input.new_empty(0)
 +        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
 +        ctx.save_for_backward(out)
 +        ctx.negative_slope = negative_slope
 +        ctx.scale = scale
 +
 +        return out
 +
 +    @staticmethod
 +    def backward(ctx, grad_output):
 +        out, = ctx.saved_tensors
 +
 +        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
 +
 +        return grad_input, grad_bias, None, None
 +
 +
 +class FusedLeakyReLU(nn.Module):
 +
 +    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
 +        super().__init__()
 +
 +        self.bias = nn.Parameter(torch.zeros(channel))
 +        self.negative_slope = negative_slope
 +        self.scale = scale
 +
 +    def forward(self, input):
 +        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
 +
 +
 +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
 +    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
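For context, a minimal usage sketch of the module added above (not part of this commit). It assumes a CUDA-capable environment in which the fused_act_ext extension can be built, e.g. via the BASICSR_JIT=True path shown in the file:

    import os
    os.environ['BASICSR_JIT'] = 'True'  # must be set before the import below so load() JIT-compiles the extension

    import torch
    from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

    x = torch.randn(4, 64, 32, 32, device='cuda')

    # Module form: owns a learnable per-channel bias.
    act = FusedLeakyReLU(64).cuda()
    y = act(x)

    # Functional form: the caller supplies the bias explicitly.
    bias = torch.zeros(64, device='cuda')
    y2 = fused_leaky_relu(x, bias, negative_slope=0.2, scale=2**0.5)

Both paths go through FusedLeakyReLUFunction, which adds the bias, applies the leaky ReLU and multiplies by scale in a single CUDA kernel call.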
diff --git a/r_basicsr/ops/fused_act/src/fused_bias_act.cpp b/r_basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 0000000..c6225bb
--- /dev/null
+++ b/r_basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
 +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
 +#include <torch/extension.h>
 +
 +
 +torch::Tensor fused_bias_act_op(const torch::Tensor& input,
 +                                const torch::Tensor& bias,
 +                                const torch::Tensor& refer,
 +                                int act, int grad, float alpha, float scale);
 +
 +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
 +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 +
 +torch::Tensor fused_bias_act(const torch::Tensor& input,
 +                             const torch::Tensor& bias,
 +                             const torch::Tensor& refer,
 +                             int act, int grad, float alpha, float scale) {
 +    CHECK_CUDA(input);
 +    CHECK_CUDA(bias);
 +
 +    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
 +}
 +
 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 +    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
 +}
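The binding above only checks that input and bias are CUDA tensors before dispatching to fused_bias_act_op. As a hedged sketch (not part of this diff; the repository's real build script may differ), the two sources could also be compiled ahead of time with torch.utils.cpp_extension, which is what the `from . import fused_act_ext` fallback in fused_act.py expects:

    # hypothetical setup.py fragment for an ahead-of-time build
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='r_basicsr-fused-act',
        ext_modules=[
            CUDAExtension(
                name='r_basicsr.ops.fused_act.fused_act_ext',
                sources=[
                    'r_basicsr/ops/fused_act/src/fused_bias_act.cpp',
                    'r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu',
                ],
            ),
        ],
        cmdclass={'build_ext': BuildExtension},
    )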
diff --git a/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000..31a536f
--- /dev/null
+++ b/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
 +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
 +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
 +//
 +// This work is made available under the Nvidia Source Code License-NC.
 +// To view a copy of this license, visit
 +// https://nvlabs.github.io/stylegan2/license.html
 +
 +#include <torch/types.h>
 +
 +#include <ATen/ATen.h>
 +#include <ATen/AccumulateType.h>
 +#include <ATen/cuda/CUDAContext.h>
 +#include <ATen/cuda/CUDAApplyUtils.cuh>
 +
 +#include <cuda.h>
 +#include <cuda_runtime.h>
 +
 +
 +template <typename scalar_t>
 +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
 +    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
 +    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
 +
 +    scalar_t zero = 0.0;
 +
 +    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
 +        scalar_t x = p_x[xi];
 +
 +        if (use_bias) {
 +            x += p_b[(xi / step_b) % size_b];
 +        }
 +
 +        scalar_t ref = use_ref ? p_ref[xi] : zero;
 +
 +        scalar_t y;
 +
 +        switch (act * 10 + grad) {
 +            default:
 +            case 10: y = x; break;
 +            case 11: y = x; break;
 +            case 12: y = 0.0; break;
 +
 +            case 30: y = (x > 0.0) ? x : x * alpha; break;
 +            case 31: y = (ref > 0.0) ? x : x * alpha; break;
 +            case 32: y = 0.0; break;
 +        }
 +
 +        out[xi] = y * scale;
 +    }
 +}
 +
 +
 +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
 +    int act, int grad, float alpha, float scale) {
 +    int curDevice = -1;
 +    cudaGetDevice(&curDevice);
 +    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
 +
 +    auto x = input.contiguous();
 +    auto b = bias.contiguous();
 +    auto ref = refer.contiguous();
 +
 +    int use_bias = b.numel() ? 1 : 0;
 +    int use_ref = ref.numel() ? 1 : 0;
 +
 +    int size_x = x.numel();
 +    int size_b = b.numel();
 +    int step_b = 1;
 +
 +    for (int i = 1 + 1; i < x.dim(); i++) {
 +        step_b *= x.size(i);
 +    }
 +
 +    int loop_x = 4;
 +    int block_size = 4 * 32;
 +    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
 +
 +    auto y = torch::empty_like(x);
 +
 +    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
 +        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
 +            y.data_ptr<scalar_t>(),
 +            x.data_ptr<scalar_t>(),
 +            b.data_ptr<scalar_t>(),
 +            ref.data_ptr<scalar_t>(),
 +            act,
 +            grad,
 +            alpha,
 +            scale,
 +            loop_x,
 +            size_x,
 +            step_b,
 +            size_b,
 +            use_bias,
 +            use_ref
 +        );
 +    });
 +
 +    return y;
 +}
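For readers tracing the switch statement: the key is act * 10 + grad. With act=3 (leaky ReLU), case 30 is the forward pass, case 31 the first-order gradient gated by the saved forward output passed in as refer, and case 32 the (zero) second-order term; every branch is finally multiplied by scale. A rough pure-PyTorch equivalent, not part of this commit and intended only as a readability aid:

    import torch

    def fused_bias_act_reference(x, bias, refer, act, grad, alpha, scale):
        # mirror the kernel's bias indexing: bias broadcasts over dim 1 (channels)
        if bias.numel():
            x = x + bias.view(1, -1, *([1] * (x.ndim - 2)))
        key = act * 10 + grad
        if key == 30:    # leaky ReLU forward
            y = torch.where(x > 0, x, x * alpha)
        elif key == 31:  # gradient w.r.t. input, gated by the saved output `refer`
            y = torch.where(refer > 0, x, x * alpha)
        elif key == 32:  # second-order term is identically zero
            y = torch.zeros_like(x)
        else:
            raise NotImplementedError(f'key {key} is not covered by this sketch')
        return y * scale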
