Skip to content

Commit

Permalink
Record means 3D at current frame for backward.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander0Yang committed Sep 15, 2024
1 parent f561b7c commit 5094ca8
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 11 deletions.
6 changes: 6 additions & 0 deletions diff-gaussian-rasterization/cuda_rasterizer/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ __device__ void computeCov3D_conditional(const glm::vec3 scale, const float scal
template<int C>
__global__ void preprocessCUDA(int P, int D, int D_t, int M,
const float* orig_points,
float* out_means3D,
const float* ts,
const glm::vec3* scales,
const float* scales_t,
Expand Down Expand Up @@ -403,6 +404,9 @@ __global__ void preprocessCUDA(int P, int D, int D_t, int M,
rotations[idx], rotations_r[idx], cov3Ds + idx * 6, p_orig, ts[idx], timestamp, idx, time_mask, opacity);
if (!time_mask) return;
cov3D = cov3Ds + idx * 6;
out_means3D[idx*3+0]=p_orig.x;
out_means3D[idx*3+1]=p_orig.y;
out_means3D[idx*3+2]=p_orig.z;
}
else
{
Expand Down Expand Up @@ -641,6 +645,7 @@ void FORWARD::render(

void FORWARD::preprocess(int P, int D, int D_t, int M,
const float* means3D,
float* out_means3D,
const float* ts,
const glm::vec3* scales,
const float* scales_t,
Expand Down Expand Up @@ -674,6 +679,7 @@ void FORWARD::preprocess(int P, int D, int D_t, int M,
preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > (
P, D, D_t, M,
means3D,
out_means3D,
ts,
scales,
scales_t,
Expand Down
1 change: 1 addition & 0 deletions diff-gaussian-rasterization/cuda_rasterizer/forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace FORWARD
// Perform initial steps for each Gaussian prior to rasterization.
void preprocess(int P, int D, int D_t, int M,
const float* orig_points,
float* out_means3D,
const float* ts,
const glm::vec3* scales,
const float* scales_t,
Expand Down
3 changes: 2 additions & 1 deletion diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace CudaRasterizer
const float* background,
const int width, int height,
const float* means3D,
float* out_means3D,
const float* shs,
const float* colors_precomp,
const float* flows_precomp,
Expand Down Expand Up @@ -66,7 +67,7 @@ namespace CudaRasterizer
const int P, int D, int D_t, int M, int R,
const float* background,
const int width, int height,
const float* means3D,
const float* out_means3D,
const float* shs,
const float* colors_precomp,
const float* flows_2d,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ int CudaRasterizer::Rasterizer::forward(
const float* background,
const int width, int height,
const float* means3D,
float* out_means3D,
const float* shs,
const float* colors_precomp,
const float* flows_precomp,
Expand Down Expand Up @@ -259,6 +260,7 @@ int CudaRasterizer::Rasterizer::forward(
CHECK_CUDA(FORWARD::preprocess(
P, D, D_t, M,
means3D,
out_means3D,
ts,
(glm::vec3*)scales,
scales_t,
Expand Down Expand Up @@ -365,7 +367,7 @@ void CudaRasterizer::Rasterizer::backward(
const int P, int D, int D_t, int M, int R,
const float* background,
const int width, int height,
const float* means3D,
const float* out_means3D,
const float* shs,
const float* colors_precomp,
const float* flows_2d,
Expand Down Expand Up @@ -455,7 +457,7 @@ void CudaRasterizer::Rasterizer::backward(
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
CHECK_CUDA(BACKWARD::preprocess(P, D, D_t, M,
(float3*)means3D,
(float3*)out_means3D,
radii,
shs,
ts,
Expand Down
10 changes: 7 additions & 3 deletions diff-gaussian-rasterization/rasterize_points.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
return lambda;
}

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down Expand Up @@ -81,6 +81,7 @@ RasterizeGaussiansCUDA(
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);
torch::Tensor out_T = torch::full({1, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Tensor out_means3D = means3D.clone();

torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
Expand Down Expand Up @@ -108,6 +109,7 @@ RasterizeGaussiansCUDA(
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
out_means3D.contiguous().data<float>(),
sh.contiguous().data_ptr<float>(),
colors.contiguous().data<float>(),
flows.contiguous().data<float>(),
Expand Down Expand Up @@ -141,13 +143,14 @@ RasterizeGaussiansCUDA(
CudaRasterizer::GeometryState geoState = CudaRasterizer::GeometryState::fromChunk(geo_ptr, P);

torch::Tensor covs3D_com = torch::from_blob(geoState.cov3D, {P, 6}, float_opts);
return std::make_tuple(rendered, out_color, out_flow, out_depth, out_T, radii, geomBuffer, binningBuffer, imgBuffer, covs3D_com);
return std::make_tuple(rendered, out_color, out_flow, out_depth, out_T, radii, geomBuffer, binningBuffer, imgBuffer, covs3D_com, out_means3D);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& out_means3D,
const torch::Tensor& radii,
const torch::Tensor& colors,
const torch::Tensor& flows_2d,
Expand Down Expand Up @@ -211,7 +214,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
CudaRasterizer::Rasterizer::backward(P, degree, degree_t, M, R,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
// means3D.contiguous().data<float>(),
out_means3D.contiguous().data<float>(),
sh.contiguous().data<float>(),
colors.contiguous().data<float>(),
flows_2d.contiguous().data<float>(),
Expand Down
3 changes: 2 additions & 1 deletion diff-gaussian-rasterization/rasterize_points.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <tuple>
#include <string>

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down Expand Up @@ -51,6 +51,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
RasterizeGaussiansBackwardCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& out_means3D,
const torch::Tensor& radii,
const torch::Tensor& colors,
const torch::Tensor& flows_2d,
Expand Down
9 changes: 5 additions & 4 deletions gaussian_renderer/diff_gaussian_rasterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,18 @@ def forward(
if raster_settings.debug:
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
try:
num_rendered, color, flow, depth, T, radii, geomBuffer, binningBuffer, imgBuffer, covs_com = _C.rasterize_gaussians(*args)
num_rendered, color, flow, depth, T, radii, geomBuffer, binningBuffer, imgBuffer, covs_com, out_means3D = _C.rasterize_gaussians(*args)
except Exception as ex:
torch.save(cpu_args, "snapshot_fw.dump")
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
raise ex
else:
num_rendered, color, flow, depth, T, radii, geomBuffer, binningBuffer, imgBuffer, covs_com = _C.rasterize_gaussians(*args)
num_rendered, color, flow, depth, T, radii, geomBuffer, binningBuffer, imgBuffer, covs_com, out_means3D = _C.rasterize_gaussians(*args)

# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.num_rendered = num_rendered
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh,
ctx.save_for_backward(colors_precomp, means3D, out_means3D, scales, rotations, cov3Ds_precomp, radii, sh,
flow_2d, opacities, ts, scales_t, rotations_r,
geomBuffer, binningBuffer, imgBuffer)
return color, radii, depth, 1-T, flow, covs_com
Expand All @@ -140,13 +140,14 @@ def backward(ctx, grad_out_color, grad_radii, grad_depth, grad_alpha, grad_flow,
# Restore necessary values from context
num_rendered = ctx.num_rendered
raster_settings = ctx.raster_settings
(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh,
(colors_precomp, means3D, out_means3D, scales, rotations, cov3Ds_precomp, radii, sh,
flow_2d, opacities, ts, scales_t, rotations_r,
geomBuffer, binningBuffer, imgBuffer) = ctx.saved_tensors

# Restructure args as C++ method expects them
args = (raster_settings.bg,
means3D,
out_means3D,
radii,
colors_precomp,
flow_2d,
Expand Down

0 comments on commit 5094ca8

Please sign in to comment.