author      Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com>  2025-01-17 11:00:30 +0000
committer   GitHub <noreply@github.com>  2025-01-17 11:00:30 +0000
commit      495ffc4777522e40941753e3b1b79c02f84b25b4 (patch)
tree        5130fcb8676afdcb619a5e5eaef3ac28e135bc08 /r_basicsr/ops/fused_act
parent      febd45814cd41560c5247aacb111d8d013f3a303 (diff)
download    Comfyui-reactor-node-495ffc4777522e40941753e3b1b79c02f84b25b4.tar.gz

Add files via upload

Diffstat (limited to 'r_basicsr/ops/fused_act')

 -rw-r--r--  r_basicsr/ops/fused_act/__init__.py                   |   3
 -rw-r--r--  r_basicsr/ops/fused_act/fused_act.py                  |  95
 -rw-r--r--  r_basicsr/ops/fused_act/src/fused_bias_act.cpp        |  26
 -rw-r--r--  r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu  | 100

 4 files changed, 224 insertions, 0 deletions
diff --git a/r_basicsr/ops/fused_act/__init__.py b/r_basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000..1f8e03b
--- /dev/null
+++ b/r_basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/r_basicsr/ops/fused_act/fused_act.py b/r_basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000..876c959
--- /dev/null
+++ b/r_basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,95 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py  # noqa:E501
+
+import os
+import torch
+from torch import nn
+from torch.autograd import Function
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+    from torch.utils.cpp_extension import load
+    module_path = os.path.dirname(__file__)
+    fused_act_ext = load(
+        'fused',
+        sources=[
+            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+        ],
+    )
+else:
+    try:
+        from . import fused_act_ext
+    except ImportError:
+        pass
+        # avoid annoying print output
+        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+        #       '1. compile with BASICSR_EXT=True. or\n '
+        #       '2. set BASICSR_JIT=True during running')
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+                                                    ctx.scale)
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
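For orientation, here is a minimal usage sketch of the module added above. It is not part of the commit; it assumes the CUDA extension is available (either prebuilt as fused_act_ext or JIT-compiled with BASICSR_JIT=True), and the tensor shapes are illustrative only.

import torch
from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(4, 64, 32, 32, device='cuda')   # illustrative NCHW activation map
act = FusedLeakyReLU(64).cuda()                  # learnable per-channel bias, slope 0.2, gain sqrt(2)
y = act(x)

# functional form with an explicit bias tensor
bias = torch.zeros(64, device='cuda')
y2 = fused_leaky_relu(x, bias)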
diff --git a/r_basicsr/ops/fused_act/src/fused_bias_act.cpp b/r_basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 0000000..c6225bb
--- /dev/null
+++ b/r_basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+                                const torch::Tensor& bias,
+                                const torch::Tensor& refer,
+                                int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+                             const torch::Tensor& bias,
+                             const torch::Tensor& refer,
+                             int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
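The binding above only becomes importable as fused_act_ext once it is compiled. The commit itself relies on the BASICSR_JIT path in fused_act.py (torch.utils.cpp_extension.load); as an alternative, an ahead-of-time build could look roughly like the sketch below. This setup.py is an assumption for illustration, not a file in this repository, and the extension name is chosen so that `from . import fused_act_ext` resolves.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='fused_act_ext',
    ext_modules=[
        CUDAExtension(
            # dotted name (assumed) places the built module inside the package
            name='r_basicsr.ops.fused_act.fused_act_ext',
            sources=[
                'r_basicsr/ops/fused_act/src/fused_bias_act.cpp',
                'r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu',
            ],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)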
diff --git a/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000..31a536f
--- /dev/null
+++ b/r_basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ? 1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
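To make the kernel's dispatch concrete: in the forward pass the Python wrapper calls fused_bias_act with act=3 and grad=0, so the switch takes case 30, i.e. an optional bias add followed by a leaky ReLU with slope alpha and a multiplication by scale. A plain-PyTorch reference of that path (a sketch for comparison, not part of the commit; it assumes the bias indexes dimension 1, as the (xi / step_b) % size_b indexing does for NCHW tensors) could look like this:

import torch
import torch.nn.functional as F

def fused_bias_act_forward_reference(x, bias, alpha=0.2, scale=2 ** 0.5):
    # mirrors case 30 above: y = x if x > 0 else x * alpha, then out = y * scale
    if bias.numel():
        x = x + bias.view(1, -1, *([1] * (x.dim() - 2)))
    return F.leaky_relu(x, negative_slope=alpha) * scale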