From c9082db821ef56042e7e14c43c3ea15c0b43e23c Mon Sep 17 00:00:00 2001 From: Atticus Zhou <119986792+Atticuszz@users.noreply.github.com> Date: Sun, 9 Jun 2024 20:27:32 +0800 Subject: [PATCH] fix: model.py Merge point clouds using a differentiable method --- src/pose_estimation/model.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/pose_estimation/model.py b/src/pose_estimation/model.py index c275e68..f0a78ee 100644 --- a/src/pose_estimation/model.py +++ b/src/pose_estimation/model.py @@ -72,19 +72,17 @@ def forward(self, depth_last, depth_current, i: int | None): pcd_last = project_depth(normalized_depth_last, pose_last, self.intrinsics) pcd_current = project_depth(normalized_depth_current, pose_cur, self.intrinsics) - # projected to depth - projected_depth_last = unproject_depth(pcd_last, pose_cur, self.intrinsics) - projected_depth_current = unproject_depth( - pcd_current, pose_cur, self.intrinsics - ) - - # combined - combined_projected_depth = torch.min( - projected_depth_last, projected_depth_current - ) - combined_projected_depth[combined_projected_depth == 0] = torch.max( - projected_depth_last, projected_depth_current - )[combined_projected_depth == 0] + # Merge point clouds using a differentiable method + # Assuming pcd_last and pcd_current are [H, W, 4] + valid_last = (pcd_last[..., 2] > 0).float() # Depth should be greater than 0 + valid_current = (pcd_current[..., 2] > 0).float() + weights_last = valid_last / (valid_last + valid_current + 1e-6) + weights_current = valid_current / (valid_last + valid_current + 1e-6) + + merged_pcd = weights_last.unsqueeze(-1) * pcd_last + weights_current.unsqueeze(-1) * pcd_current + + # Project the merged point cloud back to depth map + combined_projected_depth = unproject_depth(merged_pcd, pose_cur, self.intrinsics) # NOTE: Calculate depth loss depth_loss = compute_depth_loss(