From 20a8122bf44a81975c856750ae92a5d1017284cb Mon Sep 17 00:00:00 2001 From: Atticus Zhou <119986792+Atticuszz@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:45:00 +0800 Subject: [PATCH] fix: Update gemoetry.py --- src/pose_estimation/gemoetry.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/pose_estimation/gemoetry.py b/src/pose_estimation/gemoetry.py index b05a470..df94b89 100644 --- a/src/pose_estimation/gemoetry.py +++ b/src/pose_estimation/gemoetry.py @@ -68,19 +68,22 @@ def unproject_depth(pcd: Tensor, pose: Tensor, intrinsics: Tensor) -> Tensor: torch.Tensor The depth image created from the point cloud with dimensions [height, width]. """ - pcd_camera = torch.einsum("hwj,jk->hwk", pcd, torch.inverse(pose)) - - x = pcd_camera[..., 0] - y = pcd_camera[..., 1] - z = pcd_camera[..., 2] - u = (x / z) * intrinsics[0, 0] + intrinsics[0, 2] - v = (y / z) * intrinsics[1, 1] + intrinsics[1, 2] - - grid = torch.stack((u / z.shape[1], v / z.shape[0]), dim=-1) * 2 - 1 # Normalize to [-1, 1] - depth = z.unsqueeze(0).unsqueeze(1) - - projected_depth = F.grid_sample(depth, grid, mode='bilinear', padding_mode='zeros', align_corners=False) - return projected_depth.squeeze() + # Inverse transform to world coordinates (assuming pcd includes homogenous coordinate) + homogenous_world = torch.einsum('hwj,jk->hwk', pcd[..., 2:], torch.inverse(pose)) + + # Convert homogenous coordinates back to Euclidean coordinates in the camera frame + xyz_camera = homogenous_world[..., :3] / (homogenous_world[..., 3:4] + 1e-10) # Adding epsilon to avoid division by zero + + # Project points using intrinsics + projected = torch.einsum('ij,hwj->hwi', intrinsics, xyz_camera) + u = projected[:, :, 0] / (projected[:, :, 2] + 1e-10) # u coordinate + v = projected[:, :, 1] / (projected[:, :, 2] + 1e-10) # v coordinate + + # Map (u, v) back to depth values + # Since depth values are typically in z, directly extract z from the transformed coordinates + depth_image = xyz_camera[:, :, 2] + + return depth_image def compute_silhouette_diff(depth: Tensor, rastered_depth: Tensor) -> Tensor: