Skip to content

Commit

Permalink
Fix backward of SH in CUDA.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander0Yang committed Sep 16, 2024
1 parent 900c592 commit 3fe460f
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions diff-gaussian-rasterization/cuda_rasterizer/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,23 @@ __device__ void computeColorFromSH_4D(int idx, int deg, int deg_t, int max_coeff
float t1 = cos(2 * MY_PI * dir_t / time_duration);
float dt1_dt = sin(2 * MY_PI * dir_t / time_duration) * 2 * MY_PI / time_duration;

dL_dsh[16] = t1 * l0m0 * dL_dRGB;
dL_dsh[17] = t1 * l1m1 * dL_dRGB;
dL_dsh[18] = t1 * l1m0 * dL_dRGB;
dL_dsh[19] = t1 * l1p1 * dL_dRGB;
dL_dsh[20] = t1 * l2m2 * dL_dRGB;
dL_dsh[21] = t1 * l2m1 * dL_dRGB;
dL_dsh[22] = t1 * l2m0 * dL_dRGB;
dL_dsh[23] = t1 * l2p1 * dL_dRGB;
dL_dsh[24] = t1 * l2p2 * dL_dRGB;
dL_dsh[25] = t1 * l3m3 * dL_dRGB;
dL_dsh[26] = t1 * l3m2 * dL_dRGB;
dL_dsh[27] = t1 * l3m1 * dL_dRGB;
dL_dsh[28] = t1 * l3m0 * dL_dRGB;
dL_dsh[29] = t1 * l3p1 * dL_dRGB;
dL_dsh[30] = t1 * l3p2 * dL_dRGB;
dL_dsh[31] = t1 * l3p3 * dL_dRGB;

dRGBdt = dt1_dt * (
l0m0 * sh[16] +
l1m1 * sh[17] +
Expand Down Expand Up @@ -366,6 +383,23 @@ __device__ void computeColorFromSH_4D(int idx, int deg, int deg_t, int max_coeff
float t2 = cos(2 * MY_PI * dir_t * 2 / time_duration);
float dt2_dt = sin(2 * MY_PI * dir_t * 2 / time_duration) * 2 * MY_PI * 2 / time_duration;

dL_dsh[32] = t2 * l0m0 * dL_dRGB;
dL_dsh[33] = t2 * l1m1 * dL_dRGB;
dL_dsh[34] = t2 * l1m0 * dL_dRGB;
dL_dsh[35] = t2 * l1p1 * dL_dRGB;
dL_dsh[36] = t2 * l2m2 * dL_dRGB;
dL_dsh[37] = t2 * l2m1 * dL_dRGB;
dL_dsh[38] = t2 * l2m0 * dL_dRGB;
dL_dsh[39] = t2 * l2p1 * dL_dRGB;
dL_dsh[40] = t2 * l2p2 * dL_dRGB;
dL_dsh[41] = t2 * l3m3 * dL_dRGB;
dL_dsh[42] = t2 * l3m2 * dL_dRGB;
dL_dsh[43] = t2 * l3m1 * dL_dRGB;
dL_dsh[44] = t2 * l3m0 * dL_dRGB;
dL_dsh[45] = t2 * l3p1 * dL_dRGB;
dL_dsh[46] = t2 * l3p2 * dL_dRGB;
dL_dsh[47] = t2 * l3p3 * dL_dRGB;

dRGBdt = dt2_dt * (
l0m0 * sh[32] +
l1m1 * sh[33] +
Expand Down

0 comments on commit 3fe460f

Please sign in to comment.