Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
AtticusZeller committed Dec 28, 2024
1 parent ced2397 commit c65a49f
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 84 deletions.
8 changes: 4 additions & 4 deletions scripts/run_eval.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ run_training() {
}
cd ../src || echo "failed to cd to ../src dir !"
# Replica dataset
#run_training Replica room0 room1
#run_training Replica room2 office0
#run_training Replica office1 office2
#run_training Replica office3 office4
run_training Replica room0 room1
run_training Replica room2 office0
run_training Replica office1 office2
run_training Replica office3 office4

# TUM dataset
run_training TUM freiburg1_desk freiburg1_desk2
Expand Down
2 changes: 1 addition & 1 deletion src/GsplatLoc_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from src.eval.experiment import WandbConfig
from src.eval.utils import set_random_seed
from src.my_gsplat.gs_trainer_total import Runner
from src.my_gsplat.gs_trainer import Runner

sys.path.append("..")
set_random_seed(42)
Expand Down
129 changes: 86 additions & 43 deletions src/component/visualize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pathlib import Path
import time
from typing import List, Optional

import numpy as np
import open3d as o3d
Expand All @@ -11,7 +13,6 @@


class PcdVisualizer:

def __init__(self, intrinsic_matrix: NDArray[np.int32], view_scale=1.0):
self.o3d_vis = Visualizer()
self.o3d_vis.create_window(
Expand All @@ -30,77 +31,119 @@ def __init__(self, intrinsic_matrix: NDArray[np.int32], view_scale=1.0):
cy=intrinsic_matrix[1, 2] * view_scale,
)
self.camera_params.intrinsic = intrinsic
render_option = Path(__file__).parents[2] / "data" / "render_option.json"
self.o3d_vis.get_render_option().load_from_json(render_option.as_posix())
render_option = Path(__file__).parents[2] / "datasets" / "render_option.json"
# self.o3d_vis.get_render_option().load_from_json(render_option.as_posix())

# Initialize trajectory variables
self.camera_positions: List[np.ndarray] = []
self.camera_poses: List[np.ndarray] = []
self.trajectory_line: Optional[o3d.geometry.LineSet] = None
self.colormap = plt.get_cmap("cool")

self.line_width = 10.0 # 增加轨迹线的宽度
self.o3d_vis.get_render_option().line_width = self.line_width

def update_render(
self,
new_pcd: NDArray[np.float32],
estimate_pose: NDArray[np.float32],
new_color: NDArray[np.float32] | None = None,
down_sample: float = 0.05,
new_pcd: np.ndarray,
estimate_pose: np.ndarray,
new_color: Optional[np.ndarray] = None,
):
# new_pcd = o3d.utility.Vector3dVector(new_pcd[:, :3])
new_pcd = o3d.utility.Vector3dVector(new_pcd)
pcd_o3d = o3d.geometry.PointCloud(new_pcd)
# Create and add/update point cloud
pcd_o3d = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(new_pcd))
if new_color is not None:
pcd_o3d.colors = o3d.utility.Vector3dVector(new_color)

# pcd_o3d = pcd_o3d.random_down_sample(down_sample)

pcd_o3d.transform(estimate_pose)
self.o3d_vis.add_geometry(pcd_o3d)
self.o3d_vis.update_geometry(pcd_o3d)
self._follow_camera(estimate_pose)


if 0<=len(self.camera_positions)<=3:
# Update camera trajectory
self.camera_positions.append(estimate_pose[:3, 3])
self.camera_poses.append(estimate_pose)
self.update_trajectory()

# Follow camera
# self._follow_camera(estimate_pose)

# Update visualization
self.o3d_vis.poll_events()
self.o3d_vis.update_renderer()

def _follow_camera(self, c2w: NDArray[np.float64]):
"""Adjust the view control to follow a series of camera transformations."""
def update_trajectory(self):
# Rebuild the trajectory LineSet from self.camera_positions and add a frustum
# for the most recent entry of self.camera_poses.

# NOTE(review): the next three lines reference `c2w`, which is not defined in
# this method; they look like removed-side diff residue from the old
# _follow_camera body rather than part of update_trajectory -- confirm.
self.camera_params.extrinsic = np.linalg.inv(
c2w
) # Invert c2w and store it as the camera extrinsic
# Update line set: one segment between each pair of consecutive positions
lines = [[i, i + 1] for i in range(len(self.camera_positions) - 1)]
# One colour per segment, sampled from self.colormap by segment index; only
# the first three components are kept (presumably RGB, dropping alpha).
colors = [
self.colormap(i / (len(self.camera_positions) - 1))[:3]
for i in range(len(lines))
]

# Create the LineSet and register it with the visualizer on first use only
if self.trajectory_line is None:
self.trajectory_line = o3d.geometry.LineSet()
self.o3d_vis.add_geometry(self.trajectory_line)

self.trajectory_line.points = o3d.utility.Vector3dVector(
self.camera_positions
)
self.trajectory_line.lines = o3d.utility.Vector2iVector(lines)
self.trajectory_line.colors = o3d.utility.Vector3dVector(colors)

# Add a camera frustum for the latest pose (the marker call is disabled).
# A new frustum geometry is added on every call; nothing here removes
# frustums added by earlier calls.
latest_pose = self.camera_poses[-1]
self.add_camera_frustum(latest_pose, [1, 0, 0]) # [1, 0, 0] is passed as the colour (red if RGB), not blue
# self.add_camera_marker(self.camera_positions[-1], [0, 0, 1])

# Tell the visualizer the trajectory geometry changed
self.o3d_vis.update_geometry(self.trajectory_line)

def add_camera_frustum(self, pose: np.ndarray, color: List[float]):
# Build a camera-frustum LineSet for `pose`, paint it `color`, and add it to
# the visualizer.
# pose: matrix that is inverted below; the original comment treats it as
#       camera-to-world (c2w).
# color: passed unchanged to paint_uniform_color.

# Convert c2w to w2c
w2c = np.linalg.inv(pose)

# Width, height and intrinsic matrix come from self.camera_params.intrinsic
frustum = o3d.geometry.LineSet.create_camera_visualization(
self.camera_params.intrinsic.width,
self.camera_params.intrinsic.height,
self.camera_params.intrinsic.intrinsic_matrix.astype(np.float64),
w2c.astype(np.float64), # use w2c rather than pose
scale=0.1,
)
frustum.paint_uniform_color(color)
self.o3d_vis.add_geometry(frustum)

def add_camera_marker(self, position: np.ndarray, color: List[float]):
# Add a small sphere mesh at `position`, painted `color`, to the visualizer.
# The only visible call site (in update_trajectory) is commented out.
sphere = o3d.geometry.TriangleMesh.create_sphere(
radius=0.0001
) # Reduce the size of the trajectory points
sphere.translate(position)
sphere.paint_uniform_color(color)
self.o3d_vis.add_geometry(sphere)

def _follow_camera(self, c2w: np.ndarray):
# Point the viewer's camera at the given pose: store the inverse of `c2w` as
# the extrinsic of self.camera_params, then apply those parameters to the
# view control.
self.camera_params.extrinsic = np.linalg.inv(c2w)
self.view_control.convert_from_pinhole_camera_parameters(
self.camera_params, allow_arbitrary=True
)
def run_visualization(self):
# Print usage hints, then hand control to the visualizer's run() call
# (presumably blocks until the window is closed -- confirm against Open3D docs).
print("Visualization complete. Use mouse/keyboard to interact.")
print("Press Q or Esc to exit.")
self.o3d_vis.run()

def close(self):
# Destroy the visualizer window created in __init__.
self.o3d_vis.destroy_window()

def vis_trajectory(
self,
gt_poses: list[NDArray],
estimated_poses: list[NDArray],
downsampling_resolution: float,
fps: float,
) -> None:
"""Visualize the camera trajectory in 2D space."""
# gt_poses / estimated_poses: each pose is indexed as pose[:3, 3], i.e. the
# first three rows of its fourth column are used as the position.
# downsampling_resolution / fps: only shown in the plot title.
gt_traj = np.array([pose[:3, 3] for pose in gt_poses])
icp_traj = np.array([pose[:3, 3] for pose in estimated_poses])
# Clear the current figure so repeated calls redraw instead of stacking plots
plt.clf()
plt.title(f"Downsample ratio {downsampling_resolution}\nfps : {fps:.2f}")
# Only the first two position components are plotted
plt.plot(icp_traj[:, 0], icp_traj[:, 1], label="g-icp trajectory", linewidth=3)
# NOTE(review): legend() is called again after the second plot below, so this
# first call looks redundant -- confirm before removing.
plt.legend()
plt.plot(gt_traj[:, 0], gt_traj[:, 1], label="ground truth trajectory")
plt.legend()
plt.axis("equal")
# Short pause, presumably to let the figure refresh between calls
plt.pause(0.01)


def visualize_dataset(data_set: BaseDataset):

vis = PcdVisualizer(intrinsic_matrix=data_set.K)
for i, rgbd_image in enumerate(data_set):

for i, rgbd_image in enumerate(data_set[600:1500:20]):
print(f"Processing image {i + 1}/{len(data_set)}...")
vis.update_render(
rgbd_image.points.cpu().numpy(),
rgbd_image.pose.cpu().numpy(),
new_color=rgbd_image.colors.cpu().numpy(),
down_sample=0.01,
)

vis.run_visualization()


def visualize_trajectory(data_set: BaseDataset):
"""vis 3d trajectory"""
Expand Down
2 changes: 1 addition & 1 deletion src/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def __init__(

def __getitem__(self, index: int) -> AlignData:
assert index < len(self._data)
tar, src = self._data[index], self._data[index + 1]
tar, src = self._data[index], self._data[index + 10]
# transform to world
tar.points = transform_points(tar.pose, tar.points)
src.points = transform_points(tar.pose, src.points)
Expand Down
74 changes: 39 additions & 35 deletions src/my_gsplat/gs_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from timeit import default_timer

import torch.nn.functional as F
import torch
import tqdm

Expand Down Expand Up @@ -29,7 +29,11 @@ def __init__(

self.config = base_config
# load data
self.parser = Parser(self.sub_set, normalize=wandb_config.normalize)
self.parser = Parser(
data_set=wandb_config.dataset,
name=self.sub_set,
normalize=wandb_config.normalize,
)

# Losses & Metrics.
self.config.init_loss()
Expand Down Expand Up @@ -100,18 +104,18 @@ def train(self):
non_zero_depth_mask = (depths != 0).float()

# # RGB L1 Loss
# l1loss = F.l1_loss(
# colors * non_zero_depth_mask,
# pixels * non_zero_depth_mask,
# reduction="sum",
# ) / (non_zero_depth_mask.sum() + 1e-8)
l1loss = F.l1_loss(
colors * non_zero_depth_mask,
pixels * non_zero_depth_mask,
reduction="sum",
) / (non_zero_depth_mask.sum() + 1e-8)
#
# # SSIM Loss
# ssim_value = self.config.ssim(
# (pixels * non_zero_depth_mask).permute(0, 3, 1, 2),
# (colors * non_zero_depth_mask).permute(0, 3, 1, 2),
# )
# ssimloss = 1.0 - ssim_value
# SSIM Loss
ssim_value = self.config.ssim(
(pixels * non_zero_depth_mask).permute(0, 3, 1, 2),
(colors * non_zero_depth_mask).permute(0, 3, 1, 2),
)
ssimloss = 1.0 - ssim_value

# Depth Loss
depth_loss = compute_depth_loss(
Expand Down Expand Up @@ -146,12 +150,12 @@ def train(self):
with torch.no_grad():
# loss
self.logger.log_loss("total_loss", total_loss.item(), step=step)
# self.logger.log_loss(
# "pixels", l1loss.item(), step=step, l_type="l1"
# )
# self.logger.log_loss(
# "pixels", ssimloss.item(), step=step, l_type="ssim"
# )
self.logger.log_loss(
"pixels", l1loss.item(), step=step, l_type="l1"
)
self.logger.log_loss(
"pixels", ssimloss.item(), step=step, l_type="ssim"
)
self.logger.log_loss(
"depth", depth_loss.item(), step=step, l_type="l1"
)
Expand All @@ -169,33 +173,33 @@ def train(self):
)
# IMAGE
if step % 100 == 0:
# psnr = self.config.psnr(
# (pixels * non_zero_depth_mask).permute(0, 3, 1, 2),
# (colors * non_zero_depth_mask).permute(0, 3, 1, 2),
# )
psnr = self.config.psnr(
(pixels * non_zero_depth_mask).permute(0, 3, 1, 2),
(colors * non_zero_depth_mask).permute(0, 3, 1, 2),
)
self.logger.plot_rgbd(
depths_gt[0, :, :, 0],
depths[0, :, :, 0],
{
"type": "l1",
"value": depth_loss.item(),
},
# color=train_data.pixels,
# rastered_color=colors.squeeze(0),
# color_loss={
# "type": "psnr",
# "value": psnr.item(),
# },
color=train_data.pixels,
rastered_color=colors.squeeze(0),
color_loss={
"type": "psnr",
"value": psnr.item(),
},
# normal_loss={
# "type": "cosine",
# "value": normal_loss.item(),
# },
normal=depth_to_normal(
depths_gt[0, :, :, 0], Ks.squeeze(0)
),
rastered_normal=depth_to_normal(
depths[0, :, :, 0], Ks.squeeze(0)
),
# normal=depth_to_normal(
# depths_gt[0, :, :, 0], Ks.squeeze(0)
# ),
# rastered_normal=depth_to_normal(
# depths[0, :, :, 0], Ks.squeeze(0)
# ),
step=step,
fig_title="gs_splats Visualization",
)
Expand Down

0 comments on commit c65a49f

Please sign in to comment.