From c65a49f223026c115ee1edcbb950340301601b77 Mon Sep 17 00:00:00 2001 From: Atticuszz <1831768457@qq.com> Date: Sat, 28 Dec 2024 17:36:42 +0800 Subject: [PATCH] update code --- scripts/run_eval.sh | 8 +-- src/GsplatLoc_eval.py | 2 +- src/component/visualize.py | 129 ++++++++++++++++++++++++------------ src/data/dataset.py | 2 +- src/my_gsplat/gs_trainer.py | 74 +++++++++++---------- 5 files changed, 131 insertions(+), 84 deletions(-) diff --git a/scripts/run_eval.sh b/scripts/run_eval.sh index 8ddda60..3d67dad 100644 --- a/scripts/run_eval.sh +++ b/scripts/run_eval.sh @@ -7,10 +7,10 @@ run_training() { } cd ../src || echo "failed to cd to ../src dir !" # Replica dataset -#run_training Replica room0 room1 -#run_training Replica room2 office0 -#run_training Replica office1 office2 -#run_training Replica office3 office4 +run_training Replica room0 room1 +run_training Replica room2 office0 +run_training Replica office1 office2 +run_training Replica office3 office4 # TUM dataset run_training TUM freiburg1_desk freiburg1_desk2 diff --git a/src/GsplatLoc_eval.py b/src/GsplatLoc_eval.py index bb94d20..f416424 100644 --- a/src/GsplatLoc_eval.py +++ b/src/GsplatLoc_eval.py @@ -4,7 +4,7 @@ from src.eval.experiment import WandbConfig from src.eval.utils import set_random_seed -from src.my_gsplat.gs_trainer_total import Runner +from src.my_gsplat.gs_trainer import Runner sys.path.append("..") set_random_seed(42) diff --git a/src/component/visualize.py b/src/component/visualize.py index 553926e..68837c9 100644 --- a/src/component/visualize.py +++ b/src/component/visualize.py @@ -1,4 +1,6 @@ from pathlib import Path +import time +from typing import List, Optional import numpy as np import open3d as o3d @@ -11,7 +13,6 @@ class PcdVisualizer: - def __init__(self, intrinsic_matrix: NDArray[np.int32], view_scale=1.0): self.o3d_vis = Visualizer() self.o3d_vis.create_window( @@ -30,77 +31,119 @@ def __init__(self, intrinsic_matrix: NDArray[np.int32], view_scale=1.0): cy=intrinsic_matrix[1, 2] * view_scale, ) self.camera_params.intrinsic = intrinsic - render_option = Path(__file__).parents[2] / "data" / "render_option.json" - self.o3d_vis.get_render_option().load_from_json(render_option.as_posix()) + render_option = Path(__file__).parents[2] / "datasets" / "render_option.json" + # self.o3d_vis.get_render_option().load_from_json(render_option.as_posix()) + + # Initialize trajectory variables + self.camera_positions: List[np.ndarray] = [] + self.camera_poses: List[np.ndarray] = [] + self.trajectory_line: Optional[o3d.geometry.LineSet] = None + self.colormap = plt.get_cmap("cool") + + self.line_width = 10.0 # 增加轨迹线的宽度 + self.o3d_vis.get_render_option().line_width = self.line_width def update_render( self, - new_pcd: NDArray[np.float32], - estimate_pose: NDArray[np.float32], - new_color: NDArray[np.float32] | None = None, - down_sample: float = 0.05, + new_pcd: np.ndarray, + estimate_pose: np.ndarray, + new_color: Optional[np.ndarray] = None, ): - # new_pcd = o3d.utility.Vector3dVector(new_pcd[:, :3]) - new_pcd = o3d.utility.Vector3dVector(new_pcd) - pcd_o3d = o3d.geometry.PointCloud(new_pcd) + # Create and add/update point cloud + pcd_o3d = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(new_pcd)) if new_color is not None: pcd_o3d.colors = o3d.utility.Vector3dVector(new_color) - - # pcd_o3d = pcd_o3d.random_down_sample(down_sample) - pcd_o3d.transform(estimate_pose) self.o3d_vis.add_geometry(pcd_o3d) self.o3d_vis.update_geometry(pcd_o3d) - self._follow_camera(estimate_pose) + + + if 0<=len(self.camera_positions)<=3: + # Update camera trajectory + self.camera_positions.append(estimate_pose[:3, 3]) + self.camera_poses.append(estimate_pose) + self.update_trajectory() + + # Follow camera + # self._follow_camera(estimate_pose) + + # Update visualization self.o3d_vis.poll_events() self.o3d_vis.update_renderer() - def _follow_camera(self, c2w: NDArray[np.float64]): - """Adjust the view control to follow a series of camera transformations.""" + def update_trajectory(self): - self.camera_params.extrinsic = np.linalg.inv( - c2w - ) # Convert c2w to camera intrinsic + # Update line set + lines = [[i, i + 1] for i in range(len(self.camera_positions) - 1)] + colors = [ + self.colormap(i / (len(self.camera_positions) - 1))[:3] + for i in range(len(lines)) + ] + + if self.trajectory_line is None: + self.trajectory_line = o3d.geometry.LineSet() + self.o3d_vis.add_geometry(self.trajectory_line) + + self.trajectory_line.points = o3d.utility.Vector3dVector( + self.camera_positions + ) + self.trajectory_line.lines = o3d.utility.Vector2iVector(lines) + self.trajectory_line.colors = o3d.utility.Vector3dVector(colors) + + # Add camera frustum and marker for the latest position + latest_pose = self.camera_poses[-1] + self.add_camera_frustum(latest_pose, [1, 0, 0]) # Blue for latest + # self.add_camera_marker(self.camera_positions[-1], [0, 0, 1]) + + self.o3d_vis.update_geometry(self.trajectory_line) + + def add_camera_frustum(self, pose: np.ndarray, color: List[float]): + # 将 c2w 转换为 w2c + w2c = np.linalg.inv(pose) + + frustum = o3d.geometry.LineSet.create_camera_visualization( + self.camera_params.intrinsic.width, + self.camera_params.intrinsic.height, + self.camera_params.intrinsic.intrinsic_matrix.astype(np.float64), + w2c.astype(np.float64), # 使用 w2c 而不是 pose + scale=0.1, + ) + frustum.paint_uniform_color(color) + self.o3d_vis.add_geometry(frustum) + + def add_camera_marker(self, position: np.ndarray, color: List[float]): + sphere = o3d.geometry.TriangleMesh.create_sphere( + radius=0.0001 + ) # 减小轨迹点的大小 + sphere.translate(position) + sphere.paint_uniform_color(color) + self.o3d_vis.add_geometry(sphere) + + def _follow_camera(self, c2w: np.ndarray): + self.camera_params.extrinsic = np.linalg.inv(c2w) self.view_control.convert_from_pinhole_camera_parameters( self.camera_params, allow_arbitrary=True ) + def run_visualization(self): + print("Visualization complete. Use mouse/keyboard to interact.") + print("Press Q or Esc to exit.") + self.o3d_vis.run() def close(self): self.o3d_vis.destroy_window() - def vis_trajectory( - self, - gt_poses: list[NDArray], - estimated_poses: list[NDArray], - downsampling_resolution: float, - fps: float, - ) -> None: - """Visualize the camera trajectory in 2D space.""" - gt_traj = np.array([pose[:3, 3] for pose in gt_poses]) - icp_traj = np.array([pose[:3, 3] for pose in estimated_poses]) - plt.clf() - plt.title(f"Downsample ratio {downsampling_resolution}\nfps : {fps:.2f}") - plt.plot(icp_traj[:, 0], icp_traj[:, 1], label="g-icp trajectory", linewidth=3) - plt.legend() - plt.plot(gt_traj[:, 0], gt_traj[:, 1], label="ground truth trajectory") - plt.legend() - plt.axis("equal") - plt.pause(0.01) - def visualize_dataset(data_set: BaseDataset): - vis = PcdVisualizer(intrinsic_matrix=data_set.K) - for i, rgbd_image in enumerate(data_set): - + for i, rgbd_image in enumerate(data_set[600:1500:20]): print(f"Processing image {i + 1}/{len(data_set)}...") vis.update_render( rgbd_image.points.cpu().numpy(), rgbd_image.pose.cpu().numpy(), new_color=rgbd_image.colors.cpu().numpy(), - down_sample=0.01, ) - + vis.run_visualization() + def visualize_trajectory(data_set: BaseDataset): """vis 3d trajectory""" diff --git a/src/data/dataset.py b/src/data/dataset.py index eea8a1a..9f0c213 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -344,7 +344,7 @@ def __init__( def __getitem__(self, index: int) -> AlignData: assert index < len(self._data) - tar, src = self._data[index], self._data[index + 1] + tar, src = self._data[index], self._data[index + 10] # transform to world tar.points = transform_points(tar.pose, tar.points) src.points = transform_points(tar.pose, src.points) diff --git a/src/my_gsplat/gs_trainer.py b/src/my_gsplat/gs_trainer.py index 30bc8c8..29ee602 100644 --- a/src/my_gsplat/gs_trainer.py +++ b/src/my_gsplat/gs_trainer.py @@ -1,6 +1,6 @@ import time from timeit import default_timer - +import torch.nn.functional as F import torch import tqdm @@ -29,7 +29,11 @@ def __init__( self.config = base_config # load data - self.parser = Parser(self.sub_set, normalize=wandb_config.normalize) + self.parser = Parser( + data_set=wandb_config.dataset, + name=self.sub_set, + normalize=wandb_config.normalize, + ) # Losses & Metrics. self.config.init_loss() @@ -100,18 +104,18 @@ def train(self): non_zero_depth_mask = (depths != 0).float() # # RGB L1 Loss - # l1loss = F.l1_loss( - # colors * non_zero_depth_mask, - # pixels * non_zero_depth_mask, - # reduction="sum", - # ) / (non_zero_depth_mask.sum() + 1e-8) + l1loss = F.l1_loss( + colors * non_zero_depth_mask, + pixels * non_zero_depth_mask, + reduction="sum", + ) / (non_zero_depth_mask.sum() + 1e-8) # - # # SSIM Loss - # ssim_value = self.config.ssim( - # (pixels * non_zero_depth_mask).permute(0, 3, 1, 2), - # (colors * non_zero_depth_mask).permute(0, 3, 1, 2), - # ) - # ssimloss = 1.0 - ssim_value + # SSIM Loss + ssim_value = self.config.ssim( + (pixels * non_zero_depth_mask).permute(0, 3, 1, 2), + (colors * non_zero_depth_mask).permute(0, 3, 1, 2), + ) + ssimloss = 1.0 - ssim_value # Depth Loss depth_loss = compute_depth_loss( @@ -146,12 +150,12 @@ def train(self): with torch.no_grad(): # loss self.logger.log_loss("total_loss", total_loss.item(), step=step) - # self.logger.log_loss( - # "pixels", l1loss.item(), step=step, l_type="l1" - # ) - # self.logger.log_loss( - # "pixels", ssimloss.item(), step=step, l_type="ssim" - # ) + self.logger.log_loss( + "pixels", l1loss.item(), step=step, l_type="l1" + ) + self.logger.log_loss( + "pixels", ssimloss.item(), step=step, l_type="ssim" + ) self.logger.log_loss( "depth", depth_loss.item(), step=step, l_type="l1" ) @@ -169,10 +173,10 @@ def train(self): ) # IMAGE if step % 100 == 0: - # psnr = self.config.psnr( - # (pixels * non_zero_depth_mask).permute(0, 3, 1, 2), - # (colors * non_zero_depth_mask).permute(0, 3, 1, 2), - # ) + psnr = self.config.psnr( + (pixels * non_zero_depth_mask).permute(0, 3, 1, 2), + (colors * non_zero_depth_mask).permute(0, 3, 1, 2), + ) self.logger.plot_rgbd( depths_gt[0, :, :, 0], depths[0, :, :, 0], @@ -180,22 +184,22 @@ def train(self): "type": "l1", "value": depth_loss.item(), }, - # color=train_data.pixels, - # rastered_color=colors.squeeze(0), - # color_loss={ - # "type": "psnr", - # "value": psnr.item(), - # }, + color=train_data.pixels, + rastered_color=colors.squeeze(0), + color_loss={ + "type": "psnr", + "value": psnr.item(), + }, # normal_loss={ # "type": "cosine", # "value": normal_loss.item(), # }, - normal=depth_to_normal( - depths_gt[0, :, :, 0], Ks.squeeze(0) - ), - rastered_normal=depth_to_normal( - depths[0, :, :, 0], Ks.squeeze(0) - ), + # normal=depth_to_normal( + # depths_gt[0, :, :, 0], Ks.squeeze(0) + # ), + # rastered_normal=depth_to_normal( + # depths[0, :, :, 0], Ks.squeeze(0) + # ), step=step, fig_title="gs_splats Visualization", )