From 9acf676cecbf4e08c459684cff07699a667cfc10 Mon Sep 17 00:00:00 2001 From: LoickCh Date: Fri, 1 Mar 2024 10:34:36 +0100 Subject: [PATCH] add gs package. --- pointbev/ops/gs/.clang-format | 108 +++ pointbev/ops/gs/.gitignore | 8 + pointbev/ops/gs/cuda/gs_cuda_sparse.cu | 779 +++++++++++++++++++++ pointbev/ops/gs/cuda/gs_cuda_torch.cu | 550 +++++++++++++++ pointbev/ops/gs/cuda/gs_gpu.cpp | 110 +++ pointbev/ops/gs/cuda/neighborhood.cpp | 58 ++ pointbev/ops/gs/cuda/neighbourhood_cuda.cu | 111 +++ pointbev/ops/gs/functions/__init__.py | 1 + pointbev/ops/gs/functions/gs.py | 122 ++++ pointbev/ops/gs/include/check.h | 6 + pointbev/ops/gs/include/gs.h | 82 +++ pointbev/ops/gs/include/neighbourhood.h | 10 + pointbev/ops/gs/pybind.cpp | 15 + pointbev/ops/gs/scripts/clean.sh | 6 + pointbev/ops/gs/scripts/setup.sh | 4 + pointbev/ops/gs/setup.py | 31 + pointbev/ops/gs/tests/__init__.py | 0 pointbev/ops/gs/tests/fixtures.py | 31 + pointbev/ops/gs/tests/test_sparse.py | 162 +++++ pointbev/ops/gs/tests/utils.py | 79 +++ 20 files changed, 2273 insertions(+) create mode 100644 pointbev/ops/gs/.clang-format create mode 100644 pointbev/ops/gs/.gitignore create mode 100644 pointbev/ops/gs/cuda/gs_cuda_sparse.cu create mode 100644 pointbev/ops/gs/cuda/gs_cuda_torch.cu create mode 100644 pointbev/ops/gs/cuda/gs_gpu.cpp create mode 100644 pointbev/ops/gs/cuda/neighborhood.cpp create mode 100644 pointbev/ops/gs/cuda/neighbourhood_cuda.cu create mode 100644 pointbev/ops/gs/functions/__init__.py create mode 100644 pointbev/ops/gs/functions/gs.py create mode 100644 pointbev/ops/gs/include/check.h create mode 100644 pointbev/ops/gs/include/gs.h create mode 100644 pointbev/ops/gs/include/neighbourhood.h create mode 100644 pointbev/ops/gs/pybind.cpp create mode 100644 pointbev/ops/gs/scripts/clean.sh create mode 100644 pointbev/ops/gs/scripts/setup.sh create mode 100644 pointbev/ops/gs/setup.py create mode 100644 pointbev/ops/gs/tests/__init__.py create mode 100644 pointbev/ops/gs/tests/fixtures.py create mode 100644 pointbev/ops/gs/tests/test_sparse.py create mode 100644 pointbev/ops/gs/tests/utils.py diff --git a/pointbev/ops/gs/.clang-format b/pointbev/ops/gs/.clang-format new file mode 100644 index 0000000..bb10d98 --- /dev/null +++ b/pointbev/ops/gs/.clang-format @@ -0,0 +1,108 @@ +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -4 # -2 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Right +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: true # false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: true # false + AfterControlStatement: true # false + AfterEnum: true # false + AfterFunction: true # false + AfterNamespace: true # false + AfterObjCDeclaration: true # false + AfterStruct: true # false + AfterUnion: true # false + AfterExternBlock: true # false + BeforeCatch: true # false + BeforeElse: true # false + IndentBraces: true # false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Allman # Attach 
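+# NOTE: the trailing comments on entries above and below (e.g. "# -2", "# Attach", "# false")
+# appear to record the LLVM base-style defaults that each entry overrides.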
+BreakBeforeInheritanceComma: false +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true # false +IndentPPDirectives: None +IndentWidth: 2 #2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Left # Right +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never \ No newline at end of file diff --git a/pointbev/ops/gs/.gitignore b/pointbev/ops/gs/.gitignore new file mode 100644 index 0000000..32ac457 --- /dev/null +++ b/pointbev/ops/gs/.gitignore @@ -0,0 +1,8 @@ +build/* +dist/* +sparse_gs.egg-info/* +.vscode/ +.pytest_cache/ +__pycache__/ +*/__pycache__/ +functions/__pycache__/* \ No newline at end of file diff --git a/pointbev/ops/gs/cuda/gs_cuda_sparse.cu b/pointbev/ops/gs/cuda/gs_cuda_sparse.cu new file mode 100644 index 0000000..4344d02 --- /dev/null +++ b/pointbev/ops/gs/cuda/gs_cuda_sparse.cu @@ -0,0 +1,779 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check.h" + +using namespace at::cuda::detail; +using at::native::detail::GridSamplerInterpolation; +using at::native::detail::GridSamplerPadding; + +// ---------------------------------------------- +// Kernels: sparsed +// ---------------------------------------------- +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void sparsed_gs_3d_fw_kernel( + const index_t nthreads, TensorInfo input, TensorInfo grid, + const torch::PackedTensorAccessor index_batch_ptr, + TensorInfo output, const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, bool align_corners) +{ + const index_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= output.sizes[0]) + return; + + using opmath_t = at::opmath_type; + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t out_Npts = grid.sizes[0]; + index_t 
inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sNpts = grid.strides[0]; + index_t grid_sCoor = grid.strides[1]; + index_t out_sNpts = output.strides[0]; + index_t out_sC = output.strides[1]; + + const index_t grid_offset = index * grid_sNpts; + + // get the corresponding input x, y, z co-ordinates from grid + opmath_t x = grid.data[grid_offset]; + opmath_t y = grid.data[grid_offset + grid_sCoor]; + opmath_t z = grid.data[grid_offset + 2 * grid_sCoor]; + index_t n = index_batch_ptr[index]; + + opmath_t ix = + at::native::grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + opmath_t iy = + at::native::grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + opmath_t iz = + at::native::grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + opmath_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + opmath_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + opmath_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + opmath_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + opmath_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + opmath_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + opmath_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + opmath_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + index * out_sNpts; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + opmath_t out_acc = 0; + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if 
(at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + *out_ptr_NCDHW = out_acc; + } + } +} + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void sparsed_gs_2d_fw_kernel( + const index_t nthreads, TensorInfo input, TensorInfo grid, + const torch::PackedTensorAccessor index_batch_ptr, + TensorInfo output, const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, bool align_corners) +{ + const index_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= output.sizes[0]) + return; + + using opmath_t = at::opmath_type; + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + index_t out_Npts = grid.sizes[0]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + index_t grid_sNpts = grid.strides[0]; + index_t grid_sCoor = grid.strides[1]; + index_t out_sNpts = output.strides[0]; + index_t out_sC = output.strides[1]; + + const index_t grid_offset = index * grid_sNpts; + + // get the corresponding input x, y, z co-ordinates from grid + opmath_t x = grid.data[grid_offset]; + opmath_t y = grid.data[grid_offset + grid_sCoor]; + index_t n = index_batch_ptr[index]; + + opmath_t ix = + at::native::grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + opmath_t iy = + at::native::grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + opmath_t nw = (ix_se - ix) * (iy_se - iy); + opmath_t ne = (ix - ix_sw) * (iy_sw - iy); + opmath_t sw = (ix_ne - ix) * (iy - iy_ne); + opmath_t se = (ix - ix_nw) * (iy - iy_nw); + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + index * out_sNpts; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) + { + opmath_t out_acc = 0; + if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + *out_ptr_NCHW = out_acc; + } + } +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void sparsed_gs_3d_bw_kernel( + const index_t nthreads, TensorInfo grad_output, + TensorInfo input, TensorInfo grid, + const torch::PackedTensorAccessor index_batch_ptr, + TensorInfo grad_input, TensorInfo grad_grid, + const GridSamplerInterpolation 
interpolation_mode, const GridSamplerPadding padding_mode, + bool align_corners, const index_t grad_input_memory_span, const bool input_requires_grad) +{ + const index_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= grad_output.sizes[0]) + return; + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sNpts = grid.strides[0]; + index_t grid_sCoor = grid.strides[1]; + index_t gOut_sNpts = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + int64_t gInp_sN = 0; + int64_t gInp_sC = 0; + int64_t gInp_sD = 0; + int64_t gInp_sH = 0; + int64_t gInp_sW = 0; + if (input_requires_grad) + { + gInp_sN = grad_input.strides[0]; + gInp_sC = grad_input.strides[1]; + gInp_sD = grad_input.strides[2]; + gInp_sH = grad_input.strides[3]; + gInp_sW = grad_input.strides[4]; + } + index_t gGrid_sNpts = grad_grid.strides[0]; + + const auto grid_offset = index * grid_sNpts; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, + align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, + align_corners, &giy_mult); + iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, + align_corners, &giz_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(std::floor(ix)); + index_t iy_tnw = static_cast(std::floor(iy)); + index_t iz_tnw = static_cast(std::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + scalar_t gix = static_cast(0), giy = static_cast(0), + giz = static_cast(0); + + const index_t n = index_batch_ptr[index]; + scalar_t* gOut_ptr_NCDHW = grad_output.data + index * gOut_sNpts; + index_t 
NC_offset; + if (input_requires_grad) + { + NC_offset = n * gInp_sN; + } + scalar_t* inp_ptr_NC = input.data + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (index_t c = 0; c < C; + ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) + { + scalar_t gOut = *gOut_ptr_NCDHW; + + if (input_requires_grad) + { + at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, tnw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, tne * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, tsw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, tse * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, bnw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, bne * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, bsw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, + inp_D, inp_H, inp_W, bse * gOut, NC_offset, grad_input_memory_span); + } + // calculate grad_grid + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + scalar_t bne_val = 
inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t* gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sNpts; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } +} + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void sparsed_gs_2d_bw_kernel( + const index_t nthreads, TensorInfo grad_output, + TensorInfo input, TensorInfo grid, + const torch::PackedTensorAccessor index_batch_ptr, + TensorInfo grad_input, TensorInfo grad_grid, + const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, + bool align_corners, const index_t grad_input_memory_span, const bool input_requires_grad) +{ + const index_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= grad_output.sizes[0]) + return; + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + index_t grid_sNpts = grid.strides[0]; + index_t grid_sCoor = grid.strides[1]; + index_t gOut_sNpts = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + int64_t gInp_sN = 0; + int64_t gInp_sC = 0; + int64_t gInp_sH = 0; + int64_t gInp_sW = 0; + if (input_requires_grad) + { + gInp_sN = grad_input.strides[0]; + gInp_sC = grad_input.strides[1]; + gInp_sH = grad_input.strides[2]; + gInp_sW = grad_input.strides[3]; + } + index_t gGrid_sNpts = grad_grid.strides[0]; + + const auto grid_offset = index * grid_sNpts; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, + align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, + align_corners, &giy_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(std::floor(ix)); + index_t iy_nw = static_cast(std::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get 
surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + scalar_t gix = static_cast(0), giy = static_cast(0); + + const index_t n = index_batch_ptr[index]; + scalar_t* gOut_ptr_NCHW = grad_output.data + index * gOut_sNpts; + index_t NC_offset; + if (input_requires_grad) + { + NC_offset = n * gInp_sN; + } + scalar_t* inp_ptr_NC = input.data + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (index_t c = 0; c < C; + ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) + { + scalar_t gOut = *gOut_ptr_NCHW; + + if (input_requires_grad) + { + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + at::native::safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, + nw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, + ne * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, + sw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, + se * gOut, NC_offset, grad_input_memory_span); + } + // calculate grad_grid + if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) + { + scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) + { + scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) + { + scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) + { + scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. 
directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t* gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sNpts; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + } +} + +// ---------------------------------------------- +// Launchers +// ---------------------------------------------- +void launch_sparsed_gs_3d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) +{ + int64_t count = output.size(0); + if (count > 0) + { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "sparsed_gs_fw_kernel", + [&] + { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(output)) + { + sparsed_gs_3d_fw_kernel + <<>>( + static_cast(count), getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + else + { + sparsed_gs_3d_fw_kernel + <<>>( + count, getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} +void launch_sparsed_gs_2d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) +{ + int64_t count = output.size(0); + if (count > 0) + { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "sparsed_gs_2d_fw_kernel", + [&] + { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(output)) + { + sparsed_gs_2d_fw_kernel + <<>>( + static_cast(count), getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + else + { + sparsed_gs_2d_fw_kernel + <<>>( + count, getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_sparsed_gs_3d_bw_kernel(const at::TensorBase& grad_input, + const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask) +{ + at::globalContext().alertNotDeterministic("sparsed_gs_3d_bw_kernel"); + + // grid: Npts, Coord + int64_t count = grid.size(0); + auto input_requires_grad = output_mask[0]; + int16_t NUM_THREADS = 512; + + // clang-format off + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "sparsed_gs_3d_bw_kernel", [&] { + if ( + at::native::canUse32BitIndexMath(input) && + at::native::canUse32BitIndexMath(grid) && + at::native::canUse32BitIndexMath(grad_output)) + { + sparsed_gs_3d_bw_kernel + <<>>( + static_cast(count), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + input_requires_grad ? 
getTensorInfo(grad_input) : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), + align_corners, + /*grad_input_memory_span =*/input_requires_grad ? static_cast(grad_input.numel()) : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + sparsed_gs_3d_bw_kernel + <<>>( + count, + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + input_requires_grad ? getTensorInfo(grad_input) : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), + align_corners, + /*grad_input_memory_span =*/input_requires_grad ? grad_input.numel() : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_sparsed_gs_2d_bw_kernel(const at::TensorBase& grad_input, + const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask) +{ + at::globalContext().alertNotDeterministic("sparsed_gs_2d_bw_kernel"); + + // grid: Npts, Coord + int64_t count = grid.size(0); + auto input_requires_grad = output_mask[0]; + int16_t NUM_THREADS = 256; + + // clang-format off + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "sparsed_gs_2d_bw_kernel", [&] { + if ( + at::native::canUse32BitIndexMath(input) && + at::native::canUse32BitIndexMath(grid) && + at::native::canUse32BitIndexMath(grad_output)) + { + sparsed_gs_2d_bw_kernel + <<>>( + static_cast(count), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + input_requires_grad ? getTensorInfo(grad_input) : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), + align_corners, + /*grad_input_memory_span =*/input_requires_grad ? static_cast(grad_input.numel()) : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + sparsed_gs_2d_bw_kernel + <<>>( + count, + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + index_batch.packed_accessor(), + input_requires_grad ? getTensorInfo(grad_input) : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), + align_corners, + /*grad_input_memory_span =*/input_requires_grad ? 
grad_input.numel() : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} diff --git a/pointbev/ops/gs/cuda/gs_cuda_torch.cu b/pointbev/ops/gs/cuda/gs_cuda_torch.cu new file mode 100644 index 0000000..09dbd6c --- /dev/null +++ b/pointbev/ops/gs/cuda/gs_cuda_torch.cu @@ -0,0 +1,550 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check.h" + +using namespace at::cuda::detail; +using at::native::detail::GridSamplerInterpolation; +using at::native::detail::GridSamplerPadding; + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void torch_gs_3d_fw_kernel(const index_t nthreads, TensorInfo input, + TensorInfo grid, + TensorInfo output, + const GridSamplerInterpolation interpolation_mode, + const GridSamplerPadding padding_mode, bool align_corners) +{ + + using opmath_t = at::opmath_type; + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + index_t out_sN = output.strides[0]; + index_t out_sC = output.strides[1]; + index_t out_sD = output.strides[2]; + index_t out_sH = output.strides[3]; + index_t out_sW = output.strides[4]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) + { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + opmath_t x = grid.data[grid_offset]; + opmath_t y = grid.data[grid_offset + grid_sCoor]; + opmath_t z = grid.data[grid_offset + 2 * grid_sCoor]; + + opmath_t ix = + at::native::grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + opmath_t iy = + at::native::grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + opmath_t iz = + at::native::grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + opmath_t tnw = (ix_bse - ix) * (iy_bse - iy) * 
(iz_bse - iz); + opmath_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + opmath_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + opmath_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + opmath_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + opmath_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + opmath_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + opmath_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + opmath_t out_acc = 0; + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + out_acc += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + *out_ptr_NCDHW = out_acc; + } + } + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + index_t ix_nearest = static_cast(std::nearbyint(ix)); + index_t iy_nearest = static_cast(std::nearbyint(iy)); + index_t iz_nearest = static_cast(std::nearbyint(iz)); + + // assign nearest neighor pixel value to output pixel + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + if (at::native::within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCDHW = static_cast(0); + } + } + } + } +} + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void torch_gs_3d_bw_kernel( + const index_t nthreads, TensorInfo grad_output, + TensorInfo input, TensorInfo grid, + TensorInfo + grad_input, // initialized to zeros (or unused if input_requires_grad is false) + TensorInfo grad_grid, // initialized to empty + const GridSamplerInterpolation 
interpolation_mode, const GridSamplerPadding padding_mode, + bool align_corners, const index_t grad_input_memory_span, const bool input_requires_grad) +{ + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sD = grad_output.strides[2]; + index_t gOut_sH = grad_output.strides[3]; + index_t gOut_sW = grad_output.strides[4]; + // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false. + int64_t gInp_sN = 0; + int64_t gInp_sC = 0; + int64_t gInp_sD = 0; + int64_t gInp_sH = 0; + int64_t gInp_sW = 0; + if (input_requires_grad) + { + gInp_sN = grad_input.strides[0]; + gInp_sC = grad_input.strides[1]; + gInp_sD = grad_input.strides[2]; + gInp_sH = grad_input.strides[3]; + gInp_sW = grad_input.strides[4]; + } + index_t gGrid_sW = grad_grid.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) + { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, + align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, + align_corners, &giy_mult); + iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, + align_corners, &giz_mult); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(std::floor(ix)); + index_t iy_tnw = static_cast(std::floor(iy)); + index_t iz_tnw = static_cast(std::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + 
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + scalar_t gix = static_cast(0), giy = static_cast(0), + giz = static_cast(0); + scalar_t* gOut_ptr_NCDHW = + grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_t NC_offset; + if (input_requires_grad) + { + NC_offset = n * gInp_sN; + } + scalar_t* inp_ptr_NC = input.data + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (index_t c = 0; c < C; + ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) + { + scalar_t gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + if (input_requires_grad) + { + at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, tnw * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, tne * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, tsw * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, tse * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, bnw * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, bne * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, bsw * gOut, NC_offset, + grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, + gInp_sW, inp_D, inp_H, inp_W, bse * gOut, NC_offset, + grad_input_memory_span); + } + // calculate grad_grid + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + scalar_t tse_val = 
inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t* gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + if (input_requires_grad) + { + auto ix_nearest = static_cast(std::nearbyint(ix)); + auto iy_nearest = static_cast(std::nearbyint(iy)); + auto iz_nearest = static_cast(std::nearbyint(iz)); + + // assign nearest neighor pixel value to output pixel + scalar_t* gOut_ptr_NCDHW = + grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_t NC_offset = n * gInp_sN; + for (index_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC) + { + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + at::native::safe_add_3d(grad_input.data, iz_nearest, iy_nearest, ix_nearest, gInp_sD, + gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW, NC_offset, + grad_input_memory_span); + } + } + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t* gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NDHW[0] = static_cast(0); + gGrid_ptr_NDHW[1] = static_cast(0); + gGrid_ptr_NDHW[2] = static_cast(0); + } + } +} + +void launch_torch_gs_3d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) +{ + // See NOTE [ grid_sampler Native Functions ]. 
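+  // (As in ATen's grid_sampler_3d CUDA path, the launcher below dispatches over float/double/half
+  //  and falls back to 64-bit index arithmetic when any tensor is too large for 32-bit indexing.)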
+ // Add checks here in case this is called instead of grid_sampler. + + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + int64_t count = N * D * H * W; + if (count > 0) + { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "torch_gs_fw_kernel", + [&] + { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(output)) + { + torch_gs_3d_fw_kernel + <<>>( + static_cast(count), getTensorInfo(input), + getTensorInfo(grid), getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + else + { + torch_gs_3d_fw_kernel + <<>>( + count, getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_torch_gs_3d_bw_kernel(const at::TensorBase& grad_input, const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners, + std::array output_mask) +{ + // See NOTE [ grid_sampler Native Functions ]. + // Add checks here in case this is called instead of grid_sampler. + + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + at::globalContext().alertNotDeterministic("grid_sampler_3d_backward_cuda"); + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + int64_t count = N * D * H * W; + auto input_requires_grad = output_mask[0]; + if (count > 0) + { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "grid_sampler_3d_backward_cuda", + [&] + { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) + { + torch_gs_3d_bw_kernel + <<>>( + static_cast(count), getTensorInfo(grad_output), + getTensorInfo(input), getTensorInfo(grid), + input_requires_grad ? getTensorInfo(grad_input) + : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners, + /*grad_input_memory_span =*/ + input_requires_grad ? static_cast(grad_input.numel()) : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + else + { + torch_gs_3d_bw_kernel + <<>>( + count, getTensorInfo(grad_output), + getTensorInfo(input), getTensorInfo(grid), + input_requires_grad ? getTensorInfo(grad_input) + : TensorInfo(), + getTensorInfo(grad_grid), + static_cast(interpolation_mode), + static_cast(padding_mode), align_corners, + /*grad_input_memory_span =*/input_requires_grad ? 
grad_input.numel() : 0, + input_requires_grad); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} \ No newline at end of file diff --git a/pointbev/ops/gs/cuda/gs_gpu.cpp b/pointbev/ops/gs/cuda/gs_gpu.cpp new file mode 100644 index 0000000..eac0450 --- /dev/null +++ b/pointbev/ops/gs/cuda/gs_gpu.cpp @@ -0,0 +1,110 @@ +#include "gs.h" +#include + +// Forward +torch::Tensor sparsed_gs_3d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + const torch::Tensor& index_batch, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) +{ + // N*S,C,D,H,W + auto c = input.size(1); + auto npts = index_batch.size(0); + auto output = at::empty({npts, c}, input.options()); + launch_sparsed_gs_3d_fw_kernel(output, input, grid, index_batch, interpolation_mode, padding_mode, + align_corners); + return output; +} +torch::Tensor sparsed_gs_2d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + const torch::Tensor& index_batch, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) +{ + // N*S,C,H,W + auto c = input.size(1); + auto npts = index_batch.size(0); + auto output = at::empty({npts, c}, input.options()); + launch_sparsed_gs_2d_fw_kernel(output, input, grid, index_batch, interpolation_mode, padding_mode, + align_corners); + return output; +} + +torch::Tensor torch_gs_3d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) +{ + auto in_size = input.sizes(); + auto grid_size = grid.sizes(); + auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2], grid_size[3]}, + input.options()); + launch_torch_gs_3d_fw_kernel(output, input, grid, interpolation_mode, padding_mode, + align_corners); + return output; +} + +// Backward +std::tuple +sparsed_gs_3d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners, + std::array output_mask) +{ + auto input_requires_grad = output_mask[0]; + torch::Tensor grad_input = ([&]() { + if (input_requires_grad) + { + return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + else + { + return torch::Tensor(); + } + })(); + auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_sparsed_gs_3d_bw_kernel(grad_input, grad_grid, grad_output, input, grid, index_batch, + interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(grad_input, grad_grid); +} + +std::tuple +sparsed_gs_2d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners, + std::array output_mask) +{ + auto input_requires_grad = output_mask[0]; + torch::Tensor grad_input = ([&]() { + if (input_requires_grad) + { + return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + else + { + return torch::Tensor(); + } + })(); + auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_sparsed_gs_2d_bw_kernel(grad_input, grad_grid, grad_output, input, grid, index_batch, + interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(grad_input, grad_grid); +} + +std::tuple +torch_gs_3d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, 
std::array output_mask) +{ + auto input_requires_grad = output_mask[0]; + torch::Tensor grad_input = ([&]() { + if (input_requires_grad) + { + return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + else + { + return torch::Tensor(); + } + })(); + auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_torch_gs_3d_bw_kernel(grad_input, grad_grid, grad_output, input, grid, interpolation_mode, + padding_mode, align_corners, output_mask); + return std::make_tuple(grad_input, grad_grid); +} diff --git a/pointbev/ops/gs/cuda/neighborhood.cpp b/pointbev/ops/gs/cuda/neighborhood.cpp new file mode 100644 index 0000000..d632dfd --- /dev/null +++ b/pointbev/ops/gs/cuda/neighborhood.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include + +using namespace torch::indexing; + +std::tuple find_indices_cuda(const torch::Tensor& index_activ, + const torch::Tensor& img_mask, + std::tuple ws, + bool only_last_z) +{ + TORCH_CHECK(index_activ.is_cuda(), " must be a CUDA tensor"); + TORCH_CHECK(index_activ.is_contiguous(), " must be contiguous"); + AT_ASSERTM(index_activ.dim() == 2, "index_activ must be a 2D tensor: (Nactiv,4)"); + AT_ASSERTM(index_activ.size(1) == 4, "index_activ must be a 2D tensor: (Nactiv,4)"); + + TORCH_CHECK(img_mask.is_cuda(), " must be a CUDA tensor"); + TORCH_CHECK(img_mask.is_contiguous(), " must be contiguous"); + AT_ASSERTM(img_mask.dim() == 4, "index_activ must be a 4D tensor: (b,Z,X,Y)"); + + // index_activ: (Nactiv,4) + const int nactiv = index_activ.size(0); + + // Maximum comparisons: + int ws_z = std::get<0>(ws); + int ws_x = std::get<1>(ws); + int ws_y = std::get<2>(ws); + const int ws_prod = ws_z * ws_x * ws_y; + + // Outputs + torch::Tensor index_q = + torch::full({nactiv * ws_prod}, -1, index_activ.options().dtype(torch::kInt64)); + torch::Tensor index_k = + torch::full({nactiv * ws_prod}, -1, index_activ.options().dtype(torch::kInt64)); + + // Change mask to a matrix containing order at the location of the activated points. + const int Z = img_mask.size(1); + const int X = img_mask.size(2); + const int Y = img_mask.size(3); + torch::Tensor nonZeroIndices = + index_activ.index({Slice(), 0}) * Z * X * Y + index_activ.index({Slice(), 1}) * X * Y + + index_activ.index({Slice(), 2}) * Y + index_activ.index({Slice(), 3}); + + // Activated indices now contained their order in the list of activated points. + torch::Tensor arangeTensor = torch::arange(1, nactiv + 1, img_mask.options()); + img_mask.flatten().index_put_({nonZeroIndices}, arangeTensor); + + // Launch kernel + launch_find_indices_kernel(index_q, index_k, index_activ, img_mask, ws, only_last_z); + + // Reinitialize mask to a matrix containing 1 at the location of the activated points. 
+ torch::Tensor oneTensor = torch::ones(nactiv, img_mask.options()); + img_mask.flatten().index_put_({nonZeroIndices}, oneTensor); + + auto idx_keep = index_q != -1; + return std::make_tuple(index_q.masked_select(idx_keep), index_k.masked_select(idx_keep)); +}; \ No newline at end of file diff --git a/pointbev/ops/gs/cuda/neighbourhood_cuda.cu b/pointbev/ops/gs/cuda/neighbourhood_cuda.cu new file mode 100644 index 0000000..66fbd71 --- /dev/null +++ b/pointbev/ops/gs/cuda/neighbourhood_cuda.cu @@ -0,0 +1,111 @@ +#include +#include +#include +#include + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void find_indices_kernel( + int64_t* __restrict__ index_q, int64_t* __restrict__ index_k, + const torch::PackedTensorAccessor64 index_activ, + const torch::PackedTensorAccessor64 img_mask, + std::tuple ws, bool only_last_z) +{ + // index_q: [Nout] + int ws_z = std::get<0>(ws); + int ws_x = std::get<1>(ws); + int ws_y = std::get<2>(ws); + int64_t ws_prod = ws_z * ws_x * ws_y; + int64_t nactiv = index_activ.size(0); + scalar_t Z = img_mask.size(1); + scalar_t X = img_mask.size(2); + scalar_t Y = img_mask.size(3); + + // Threads + int64_t elem = threadIdx.x + blockIdx.x * blockDim.x; + + if (elem >= nactiv) + { + return; + } + + scalar_t range_z = static_cast((ws_z - 1) / 2); + scalar_t range_x = static_cast((ws_x - 1) / 2); + scalar_t range_y = static_cast((ws_y - 1) / 2); + + scalar_t q_id_b = index_activ[elem][0]; + scalar_t q_id_z = index_activ[elem][1]; + scalar_t q_id_x = index_activ[elem][2]; + scalar_t q_id_y = index_activ[elem][3]; + if (only_last_z && (q_id_z != (Z - 1))) + { + return; + } + + int64_t cnt = 0; + for (int64_t iz = -range_z; iz <= range_z; iz++) + { + for (int64_t ix = -range_x; ix <= range_x; ix++) + { + for (int64_t iy = -range_y; iy <= range_y; iy++) + { + int32_t k_id_z = q_id_z + iz; + int32_t k_id_x = q_id_x + ix; + int32_t k_id_y = q_id_y + iy; + + if ((k_id_z < 0) || (k_id_x < 0) || (k_id_y < 0) || (k_id_z >= Z) || (k_id_x >= X) || + (k_id_y >= Y)) + { + cnt++; + continue; + } + + if (img_mask[q_id_b][k_id_z][k_id_x][k_id_y] != 0) + { + // Outputs as [Nout,3] + // index_q[elem * ws_2 + cnt][0] = q_id_b; + // index_q[elem * ws_2 + cnt][1] = q_id_x; + // index_q[elem * ws_2 + cnt][2] = q_id_y; + + // index_k[elem * ws_2 + cnt][0] = q_id_b; + // index_k[elem * ws_2 + cnt][1] = k_id_x; + // index_k[elem * ws_2 + cnt][2] = k_id_y; + + index_q[elem * ws_prod + cnt] = elem; + index_k[elem * ws_prod + cnt] = img_mask[q_id_b][k_id_z][k_id_x][k_id_y] - 1; + } + + cnt++; + } + } + } + return; +}; + +void launch_find_indices_kernel(torch::Tensor index_q, torch::Tensor index_k, + const torch::Tensor& index_activ, const torch::Tensor& img_mask, + std::tuple ws, bool only_last_z) +{ + const int64_t threads = 512; + + // index_activ: [Nactiv, 4] + const int64_t nactiv = index_activ.size(0); + const int64_t count = nactiv; + + if (count > 0) + { + AT_DISPATCH_INTEGRAL_TYPES( + index_activ.scalar_type(), "find_indices_kernel", + ( + [&] + { + find_indices_kernel<<>>( + index_q.data_ptr(), index_k.data_ptr(), + index_activ.packed_accessor64(), + img_mask.packed_accessor64(), ws, + only_last_z); + })); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +}; \ No newline at end of file diff --git a/pointbev/ops/gs/functions/__init__.py b/pointbev/ops/gs/functions/__init__.py new file mode 100644 index 0000000..7e89c3d --- /dev/null +++ b/pointbev/ops/gs/functions/__init__.py @@ -0,0 +1 @@ +from .gs import sparsed_grid_sample, torch_grid_sample diff --git a/pointbev/ops/gs/functions/gs.py 
b/pointbev/ops/gs/functions/gs.py
new file mode 100644
index 0000000..8bd4f57
--- /dev/null
+++ b/pointbev/ops/gs/functions/gs.py
@@ -0,0 +1,122 @@
+import pdb
+
+import torch
+
+import sparse_gs  # isort: skip
+
+
+class SparsedGridSampleFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        grid,
+        index_batch,
+        interpolation_mode,
+        padding_mode,
+        align_corners,
+    ):
+        if len(input.shape) == 5:
+            func = sparse_gs.forward_sparse
+        elif len(input.shape) == 4:
+            func = sparse_gs.forward_2d_sparse
+        out = func(
+            input,
+            grid,
+            index_batch,
+            interpolation_mode,
+            padding_mode,
+            align_corners,
+        )
+        ctx.save_for_backward(*[input, grid, index_batch])
+
+        ctx.interpolation_mode = interpolation_mode
+        ctx.padding_mode = padding_mode
+        ctx.align_corners = align_corners
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, grid, index_batch = ctx.saved_tensors
+        output_mask = (ctx.needs_input_grad[0], ctx.needs_input_grad[1])
+        if len(input.shape) == 5:
+            func = sparse_gs.backward_sparse
+        elif len(input.shape) == 4:
+            func = sparse_gs.backward_2d_sparse
+        grad_input, grad_grid = func(
+            grad_output.contiguous(),
+            input,
+            grid,
+            index_batch,
+            ctx.interpolation_mode,
+            ctx.padding_mode,
+            ctx.align_corners,
+            output_mask,
+        )
+        return grad_input, grad_grid, None, None, None, None
+
+
+class TorchGridSampleFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        grid,
+        interpolation_mode,
+        padding_mode,
+        align_corners,
+    ):
+        out = sparse_gs.forward_torch(
+            input,
+            grid,
+            interpolation_mode,
+            padding_mode,
+            align_corners,
+        )
+        ctx.save_for_backward(*[input, grid])
+
+        ctx.interpolation_mode = interpolation_mode
+        ctx.padding_mode = padding_mode
+        ctx.align_corners = align_corners
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, grid = ctx.saved_tensors
+        output_mask = (ctx.needs_input_grad[0], ctx.needs_input_grad[1])
+        grad_input, grad_grid = sparse_gs.backward_torch(
+            grad_output.contiguous(),
+            input,
+            grid,
+            ctx.interpolation_mode,
+            ctx.padding_mode,
+            ctx.align_corners,
+            output_mask,
+        )
+        return grad_input, grad_grid, None, None, None
+
+
+# Functions
+def sparsed_grid_sample(
+    input, grid, index_batch, interpolation_mode=0, padding_mode=0, align_corners=False
+):
+    return SparsedGridSampleFunction.apply(
+        input,
+        grid,
+        index_batch,
+        interpolation_mode,
+        padding_mode,
+        align_corners,
+    )
+
+
+def torch_grid_sample(
+    input, grid, interpolation_mode=0, padding_mode=0, align_corners=False
+):
+    return TorchGridSampleFunction.apply(
+        input,
+        grid,
+        interpolation_mode,
+        padding_mode,
+        align_corners,
+    )
diff --git a/pointbev/ops/gs/include/check.h b/pointbev/ops/gs/include/check.h
new file mode 100644
index 0000000..c9969e5
--- /dev/null
+++ b/pointbev/ops/gs/include/check.h
@@ -0,0 +1,6 @@
+#include
+#include
+#pragma once
+
+using at::native::detail::GridSamplerInterpolation;
+using at::native::detail::GridSamplerPadding;
diff --git a/pointbev/ops/gs/include/gs.h b/pointbev/ops/gs/include/gs.h
new file mode 100644
index 0000000..aa19fcc
--- /dev/null
+++ b/pointbev/ops/gs/include/gs.h
@@ -0,0 +1,82 @@
+#pragma once
+#include
+#include
+
+using at::native::detail::GridSamplerInterpolation;
+using at::native::detail::GridSamplerPadding;
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x); \
+  
CHECK_CONTIGUOUS(x) + +// --------------------------------------- +// Kernels +// --------------------------------------- +torch::Tensor sparsed_gs_3d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + const torch::Tensor& index_batch, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + +torch::Tensor sparsed_gs_2d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + const torch::Tensor& index_batch, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + +torch::Tensor torch_gs_3d_fw_cuda(const torch::Tensor& input, const torch::Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +// Backward +std::tuple +sparsed_gs_3d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners, + std::array output_mask); + +std::tuple +sparsed_gs_2d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners, + std::array output_mask); + +std::tuple +torch_gs_3d_bw_cuda(const torch::Tensor& grad_output, const torch::Tensor& input, + const torch::Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +// --------------------------------------- +// Launchers +// --------------------------------------- +void launch_sparsed_gs_3d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +void launch_sparsed_gs_2d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners); + +void launch_torch_gs_3d_fw_kernel(const at::TensorBase& output, const at::TensorBase& input, + const at::TensorBase& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners); + +// Backward +void launch_sparsed_gs_3d_bw_kernel(const at::TensorBase& grad_input, + const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); +void launch_sparsed_gs_2d_bw_kernel(const at::TensorBase& grad_input, + const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, const torch::Tensor& index_batch, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +void launch_torch_gs_3d_bw_kernel(const at::TensorBase& grad_input, const at::TensorBase& grad_grid, + const at::TensorBase& grad_output, const at::TensorBase& input, + const at::TensorBase& grid, int64_t interpolation_mode, + int64_t padding_mode, bool align_corners, + std::array output_mask); \ No newline at end of file diff --git a/pointbev/ops/gs/include/neighbourhood.h b/pointbev/ops/gs/include/neighbourhood.h new file mode 100644 index 0000000..9b6992e --- /dev/null +++ b/pointbev/ops/gs/include/neighbourhood.h @@ -0,0 +1,10 @@ +#include + +std::tuple find_indices_cuda(const torch::Tensor& index_activ, + const torch::Tensor& 
img_mask, + std::tuple ws, + bool only_last_z); + +void launch_find_indices_kernel(torch::Tensor index_q, torch::Tensor index_k, + const torch::Tensor& index_activ, const torch::Tensor& img_mask, + std::tuple ws, bool only_last_z); \ No newline at end of file diff --git a/pointbev/ops/gs/pybind.cpp b/pointbev/ops/gs/pybind.cpp new file mode 100644 index 0000000..ddd0dc1 --- /dev/null +++ b/pointbev/ops/gs/pybind.cpp @@ -0,0 +1,15 @@ +#include + +#include "gs.h" +#include "neighbourhood.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward_torch", &torch_gs_3d_fw_cuda); + m.def("backward_torch", &torch_gs_3d_bw_cuda); + m.def("forward_sparse", &sparsed_gs_3d_fw_cuda); + m.def("backward_sparse", &sparsed_gs_3d_bw_cuda); + m.def("forward_2d_sparse", &sparsed_gs_2d_fw_cuda); + m.def("backward_2d_sparse", &sparsed_gs_2d_bw_cuda); + m.def("find_indices", &find_indices_cuda); +} \ No newline at end of file diff --git a/pointbev/ops/gs/scripts/clean.sh b/pointbev/ops/gs/scripts/clean.sh new file mode 100644 index 0000000..735604b --- /dev/null +++ b/pointbev/ops/gs/scripts/clean.sh @@ -0,0 +1,6 @@ +rm -rf ./build +rm -rf ./dist +rm -rf ./sparse_gs.egg-info +rm -rf __pycache__ +rm -rf .pytest_cache +rm -rf ~/micromamba/envs/bevsegm/lib/python3.11/site-packages/sparse_gs-1.0-py3.11-linux-x86_64.egg/ \ No newline at end of file diff --git a/pointbev/ops/gs/scripts/setup.sh b/pointbev/ops/gs/scripts/setup.sh new file mode 100644 index 0000000..f3de4c4 --- /dev/null +++ b/pointbev/ops/gs/scripts/setup.sh @@ -0,0 +1,4 @@ +zsh scripts/clean.sh +clear +python setup.py build install +python -c 'import torch; import sparse_gs; print(sparse_gs.forward_torch); print(sparse_gs.backward_torch); print(sparse_gs.forward_sparse); print(sparse_gs.backward_sparse); print(sparse_gs.forward_2d_sparse); print(sparse_gs.backward_2d_sparse)' \ No newline at end of file diff --git a/pointbev/ops/gs/setup.py b/pointbev/ops/gs/setup.py new file mode 100644 index 0000000..09fbc2d --- /dev/null +++ b/pointbev/ops/gs/setup.py @@ -0,0 +1,31 @@ +import glob +import os.path as osp + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ROOT_DIR = osp.dirname(osp.abspath(__file__)) +include_dirs = [osp.join(ROOT_DIR, "include")] + +sources = glob.glob("*.cpp") + glob.glob("*.cu") +for dir in ["cuda", "cpu"]: + sources += glob.glob(f"{dir}/*.cpp") + sources += glob.glob(f"{dir}/*.cu") + +print(sources) +setup( + name="sparse_gs", + version="1.0", + author="Loick Chambon", + author_email="loick.chambon@valeo.com", + description="Sparse grid sampling.", + ext_modules=[ + CUDAExtension( + name="sparse_gs", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={"cxx": ["-O2"], "nvcc": ["-O2"]}, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/pointbev/ops/gs/tests/__init__.py b/pointbev/ops/gs/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pointbev/ops/gs/tests/fixtures.py b/pointbev/ops/gs/tests/fixtures.py new file mode 100644 index 0000000..2f1fdc6 --- /dev/null +++ b/pointbev/ops/gs/tests/fixtures.py @@ -0,0 +1,31 @@ +import pytest +import torch + +torch.set_printoptions(precision=12) +N, S, C, D, H, W = 1, 6, 128, 1, 28, 60 +Do, Ho, Wo = 1, 1, 200 * 200 * 8 + + +@pytest.fixture(params=[0.0, 0.16, 0.25, 0.5, 1.0]) +def set_pct_mask(request): + return request.param + + +@pytest.fixture +def input_data(set_pct_mask): + device = "cuda" + + rand = torch.randn(N, S, C, D, H, W, device=device) + grid = torch.rand(N, 
S, Do, Ho, Wo, 3, device=device, requires_grad=False) * 2 - 1 + mask = ( + torch.rand(N, S, Do, Ho, Wo, device=device, requires_grad=False) + .le(set_pct_mask) + .float() + ) + + return ( + rand, + grid, + mask, + {"pct": f"mask:{set_pct_mask}"}, + ) diff --git a/pointbev/ops/gs/tests/test_sparse.py b/pointbev/ops/gs/tests/test_sparse.py new file mode 100644 index 0000000..367cdb8 --- /dev/null +++ b/pointbev/ops/gs/tests/test_sparse.py @@ -0,0 +1,162 @@ +import pdb +import time +from math import prod +from sys import path + +import pytest +import torch + +from tests.fixtures import * + +from .utils import * + +path.insert(0, "../") +from functions import sparsed_grid_sample, torch_grid_sample + +N_REPEAT = 1000 +W_MASK = False +W_MODULE = False + + +# @memory_decorator +@time_decorator(N_REPEAT) +def gs_torch_fw(is_2d, input, grid, mask): + if W_MODULE: + out = torch_grid_sample(input, grid, 0, 0, False) + else: + out = torch.nn.functional.grid_sample(input, grid, "bilinear", "zeros", False) + if W_MASK: + if is_2d: + out = rearrange(out, "b c h w -> b h w c")[mask.bool()] + else: + out = rearrange(out, "b c d h w -> b d h w c")[mask.bool()] + return out + + +# @memory_decorator +@time_decorator(N_REPEAT) +def gs_torch_bw(is_2d, input, grid, mask): + out = torch.nn.functional.grid_sample(input, grid, "bilinear", "zeros", False) + if W_MASK: + if is_2d: + out = rearrange(out, "b c h w -> b h w c")[mask.bool()] + else: + out = rearrange(out, "b c d h w -> b d h w c")[mask.bool()] + out.mean().backward(retain_graph=True) + return out + + +def gs_torch(is_2d, input, grid, mask): + out_fw = gs_torch_fw(is_2d, input, grid, mask) + out_bw = gs_torch_bw(is_2d, input, grid, mask) + return out_fw, out_bw + + +# @memory_decorator +@time_decorator(N_REPEAT) +def gs_sparse_pckg_fw(input, grid, index_batch): + out = sparsed_grid_sample(input, grid, index_batch, 0, 0, False) + return out + + +# @memory_decorator +@time_decorator(N_REPEAT) +def gs_sparse_pckg_bw(input, grid, index_batch): + out = sparsed_grid_sample(input, grid, index_batch, 0, 0, False) + out.mean().backward(retain_graph=True) + return out + + +def gs_sparse_pckg(input, grid, index_batch): + out_pckg_fw = gs_sparse_pckg_fw(input, grid, index_batch) + out_pckg_bw = gs_sparse_pckg_bw(input, grid, index_batch) + return out_pckg_fw, out_pckg_bw + + +@pytest.mark.parametrize("is_2d, w_torch, w_pckg", [(True, True, True)]) +def test_compare_sparse(input_data, is_2d, w_torch, w_pckg): + rand, grid, mask, dict_ = input_data + + # Inputs + if is_2d: + rand = rand[:, :, :, 0].contiguous() + grid = grid[..., 0, :, :2].contiguous() + mask = mask[..., 0, :].contiguous() + if w_torch: + inp_torch, grid_torch, grid_mask_torch = get_input(rand, grid, mask) + if w_pckg: + inp_pckg, grid_pckg, grid_mask_pckg = get_input(rand, grid, mask) + index_batch = torch.arange( + inp_pckg.shape[0], device=inp_pckg.device, dtype=torch.int16 + ).repeat_interleave(prod(grid_pckg.shape[1:-1])) + index_batch = index_batch[grid_mask_pckg.view(-1).bool()].contiguous() + grid_pckg_ = grid_pckg.view(-1, 3 if not is_2d else 2)[ + grid_mask_pckg.view(-1).bool() + ].contiguous() + reset_mem() + + # fmt: off + print(f"\nPct: {dict_['pct']}") + if w_torch: + if is_2d: + assert not W_MODULE + out_torch_fw, out_torch_bw = gs_torch(is_2d, inp_torch,grid_torch,grid_mask_torch) + if w_pckg: + out_sparse_fw, out_sparse_bw = gs_sparse_pckg(inp_pckg,grid_pckg_,index_batch) + print('\n================================') + # fmt: on + + if w_torch and w_pckg and W_MASK: + assert 
out_torch_fw.shape == out_sparse_fw.shape + assert torch.equal(out_torch_fw, out_sparse_fw) + + assert out_torch_bw.shape == out_sparse_bw.shape + assert torch.equal(out_torch_bw, out_sparse_bw) + + assert (inp_torch.grad is not None) and (inp_pckg.grad is not None) + assert torch.allclose(inp_torch.grad, inp_pckg.grad) + + assert (grid_torch.grad is not None) and (grid_pckg.grad is not None) + assert torch.allclose(grid_torch.grad, grid_pckg.grad) + + return + + +@pytest.mark.parametrize( + "dtype_inp_grid, dtype_index", + [ + ("half", "half"), + ("half", "int"), + ("half", "long"), + ("float", "half"), + ("float", "int"), + ("float", "long"), + ], +) +def test_dtype(dtype_inp_grid, dtype_index): + # N, C, D, H, W, Npts = 64, 128, 1, 28, 60, 18_000_000 + N, C, D, H, W, Npts = 1, 128, 1, 28, 60, 10 + + if dtype_inp_grid == "float": + dtype_inp_grid = torch.float32 + elif dtype_inp_grid == "double": + dtype_inp_grid = torch.float64 + elif dtype_inp_grid == "half": + dtype_inp_grid = torch.float16 + + if dtype_index == "int": + dtype_index = torch.int32 + elif dtype_index == "long": + dtype_index = torch.int64 + elif dtype_index == "half": + dtype_index = torch.int16 + + inp_pckg = torch.randn( + N, C, D, H, W, device="cuda", dtype=dtype_inp_grid, requires_grad=True + ) + grid_pckg_ = torch.rand(Npts, C, device="cuda", dtype=dtype_inp_grid) + index_batch = torch.randint(0, N, (Npts,), device="cuda", dtype=dtype_index) + + gs_sparse_pckg(inp_pckg, grid_pckg_, index_batch) + del inp_pckg, grid_pckg_, index_batch + reset_mem() diff --git a/pointbev/ops/gs/tests/utils.py b/pointbev/ops/gs/tests/utils.py new file mode 100644 index 0000000..2f93399 --- /dev/null +++ b/pointbev/ops/gs/tests/utils.py @@ -0,0 +1,79 @@ +import functools +import time + +import torch +from einops import rearrange + + +def time_decorator(N): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + nonlocal N + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + delta = 0 + for _ in range(N): + start.record() + result = func(*args, **kwargs) + end.record() + torch.cuda.synchronize() + if _ <= 5: # Warmup + pass + else: + delta += start.elapsed_time(end) * 1e-3 + + print(f"{func.__name__} took {delta:.4f} seconds") + return result + + return wrapper + + return decorator + + +def memory_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + reset_mem() + result = func(*args, **kwargs) + print(f"Max mem {func.__name__}: {torch.cuda.memory_allocated()/(2**30):.4f}") + return result + + return wrapper + + +def reset_mem(): + torch.cuda.torch.cuda.reset_peak_memory_stats("cuda") + torch.cuda.torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + +def get_input(rand, grid, mask=None): + """Arange input shape to be compatible with the gs.""" + # Inputs + input = rand.flatten(0, 1).clone().requires_grad_() + grid = grid.flatten(0, 1).requires_grad_() + mask = mask.flatten(0, 1) if mask is not None else None + return input, grid, mask + + +def set_inp_to_inf(inp, inp_mask, mask): + NS, C, D, H, W = inp.shape + + inp = rearrange(inp, "ns c d h w -> (ns d h w) c") + inp_mask = rearrange(inp_mask, "ns d h w -> (ns d h w)") + inp = inp.detach() + inp[~inp_mask.bool()] = torch.inf + inp = rearrange(inp, "(ns d h w) c -> ns c d h w", ns=NS, c=C, d=D, h=H, w=W) + # mask = torch.where(mask == 1, mask, torch.nan) + mask = torch.where(mask == 1, mask, torch.inf) + + return inp.clone().requires_grad_(), mask + + +def set_inf_to_zero(out): + out 
= torch.nan_to_num(out, 0.0)
+    # nan_to_num zeroes NaNs but maps +/-inf to large finite values,
+    # so the remaining clamps zero out any huge magnitudes as well.
+    out = torch.where(out != torch.inf, out, 0.0)
+    out = torch.where(out < 1e38, out, 0.0)
+    out = torch.where(out > -1e38, out, 0.0)
+    return out
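
A minimal usage sketch of the Python entry point added in functions/gs.py, mirroring how tests/test_sparse.py builds its inputs. The shapes, the 25% keep ratio, and the import path are illustrative assumptions; interpolation_mode=0 and padding_mode=0 follow ATen's enum convention (bilinear sampling, zero padding).

# Sketch under the assumptions stated above; values are illustrative, not part of the patch.
from math import prod

import torch

from functions import sparsed_grid_sample  # as in tests/test_sparse.py (gs/ on sys.path)

B, C, D, H, W = 2, 128, 1, 28, 60
feats = torch.randn(B, C, D, H, W, device="cuda", requires_grad=True)

# Dense sampling grid in [-1, 1] plus a boolean mask of the points to keep.
grid = torch.rand(B, 1, 1, 500, 3, device="cuda") * 2 - 1
keep = torch.rand(B, 1, 1, 500, device="cuda") < 0.25

# Sparse layout expected by the op: (Npts, 3) coordinates and a per-point batch index
# (any integer dtype, as exercised by test_dtype above).
index_batch = torch.arange(B, device="cuda").repeat_interleave(prod(grid.shape[1:-1]))
index_batch = index_batch[keep.view(-1)].contiguous()
points = grid.view(-1, 3)[keep.view(-1)].contiguous()

out = sparsed_grid_sample(feats, points, index_batch, 0, 0, False)  # -> (Npts, C)
out.mean().backward()

Unlike torch.nn.functional.grid_sample, which returns the dense (B, C, Do, Ho, Wo) volume, the sparse op returns one (Npts, C) row per kept point, in the same order as index_batch.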
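
A similar sketch for the neighbourhood helper exposed as sparse_gs.find_indices (neighborhood.cpp / neighbourhood_cuda.cu). The dtypes are an assumption: the single AT_DISPATCH_INTEGRAL_TYPES call suggests index_activ and img_mask share one integer dtype. Note that img_mask is overwritten in place during the call and reset to ones at the active cells before returning.

# Sketch under the assumptions stated above; not part of the patch.
import torch

import sparse_gs  # built e.g. via scripts/setup.sh

B, Z, X, Y = 1, 8, 200, 200
img_mask = (torch.rand(B, Z, X, Y, device="cuda") < 0.1).long()  # nonzero at active cells
index_activ = img_mask.nonzero().contiguous()  # (Nactiv, 4) rows of (b, z, x, y), int64

# Pairs of active points lying within a (ws_z, ws_x, ws_y) = (1, 3, 3) window of each
# other; only_last_z=False keeps queries from every z slice.
index_q, index_k = sparse_gs.find_indices(index_activ, img_mask, (1, 3, 3), False)

Each returned pair (index_q[i], index_k[i]) indexes rows of index_activ, pairing an active point with an active neighbour inside the window.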