Skip to content

Commit

Permalink
chore: clean slam version
Browse files Browse the repository at this point in the history
  • Loading branch information
AtticusZeller committed Jul 30, 2024
1 parent 4000c0d commit c28e706
Show file tree
Hide file tree
Showing 12 changed files with 241 additions and 1,142 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
## GSLoc-Slam
[ ] clean config in base.py
[x] clean config in base.py
1. tracking
1. [x] normalize pcd and pose via PCA
2. [x] update depth_gt with a proper method
3. [x] loss with depth and edge and normals
4. [x] find an early stop condition !!! -> total loss and later than 100 step
5. [ ] sync data shape and avoid too much middle vars through backward
5. [x] sync data shape and avoid too much middle vars through backward
6. [x] total dataset eval

6. [ ] total dataset eval
1. [x] simply
2. [ ] add gs.add_gs method
4 changes: 2 additions & 2 deletions src/eval/logger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from datetime import datetime
from pathlib import Path
from typing import Literal
import pandas

import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from torch import Tensor

import wandb

from .utils import calculate_RMSE
from ..my_gsplat.geometry import compute_silhouette_diff
from .utils import calculate_RMSE


class WandbLogger:
Expand Down
98 changes: 76 additions & 22 deletions src/gsplat_run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import sys
import time
from pathlib import Path
Expand All @@ -8,43 +9,96 @@
from src.eval.experiment import WandbConfig


# from my_gsplat.gs_trainer_total2 import Runner
def parse_arguments():
    """Parse the command-line options for a GSplat training run.

    The room-selection options (``--rooms``, ``--all``, ``--room-range``,
    ``--office-range``) are only collected here; turning them into a list of
    names is left to the caller.

    Returns:
        argparse.Namespace: with attributes ``rooms``, ``all``,
        ``room_range``, ``office_range``, ``cam_opt``, ``num_iters`` and
        ``disable_viewer``.
    """
    parser = argparse.ArgumentParser(
        description="Run GSplat training on specified rooms."
    )
    parser.add_argument(
        "--rooms",
        nargs="+",
        help="Specify room names manually (e.g., room0 room1 office0)",
    )
    parser.add_argument(
        "--all", action="store_true", help="Run for all rooms (room0-2 and office0-4)"
    )
    parser.add_argument(
        "--room-range",
        nargs=2,
        type=int,
        metavar=("START", "END"),
        help="Specify a range of room numbers (e.g., 0 2 for room0 to room2)",
    )
    parser.add_argument(
        "--office-range",
        nargs=2,
        type=int,
        metavar=("START", "END"),
        help="Specify a range of office numbers (e.g., 0 4 for office0 to office4)",
    )
    parser.add_argument(
        "--cam-opt",
        choices=["quat", "6d", "6d+"],
        default="quat",
        help="Camera optimization method (default: quat)",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=2000,
        help="Number of iterations (default: 2000)",
    )
    # store_true combined with default=True could never yield False, so the
    # viewer was impossible to enable. BooleanOptionalAction keeps the default
    # and the existing --disable-viewer flag, and adds --no-disable-viewer.
    parser.add_argument(
        "--disable-viewer",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Disable the viewer (use --no-disable-viewer to enable it)",
    )
    return parser.parse_args()


# from my_gsplat.gs_trainer import Runner


# from my_gsplat.trainer import Runner
def get_rooms(args):
    """Resolve the parsed room-selection options into a list of scene names.

    ``args.all`` takes precedence and yields room0-2 followed by office0-4.
    Otherwise the explicit ``args.rooms``, the inclusive ``args.room_range``
    and the inclusive ``args.office_range`` are concatenated in that order,
    so giving several of them no longer silently drops the later ones.

    Args:
        args: namespace with ``all``, ``rooms``, ``room_range`` and
            ``office_range`` attributes (unset options are falsy).

    Returns:
        list[str]: the selected names; ``["room0"]`` when no selection
        option was given at all.
    """
    if args.all:
        return [f"room{i}" for i in range(3)] + [f"office{i}" for i in range(5)]

    # Default applies only when nothing was requested; an inverted range
    # (START > END) still produces an empty selection rather than room0.
    if not (args.rooms or args.room_range or args.office_range):
        return ["room0"]

    selected = list(args.rooms or [])
    if args.room_range:
        first, last = args.room_range
        selected.extend(f"room{i}" for i in range(first, last + 1))
    if args.office_range:
        first, last = args.office_range
        selected.extend(f"office{i}" for i in range(first, last + 1))
    return selected


def main():
    """Train one Runner per room selected on the command line.

    For each room a WandbConfig is built from the parsed options, a Runner
    is created with the chosen camera optimisation, and training is run.
    When the viewer is not disabled, the loop pauses briefly after each room.
    """
    args = parse_arguments()

    for room in get_rooms(args):
        config = WandbConfig(
            sub_set=room,
            algorithm="gsplat_v4_filter_knn10-10",
            implementation="pytorch",
            num_iters=args.num_iters,
            normalize=True,
        )

        runner = Runner(config, extra_config={"cam_opt": args.cam_opt})
        runner.config.adjust_steps()
        runner.train()

        if not args.disable_viewer:
            print(f"Viewer running for {room}... Ctrl+C to move to the next room.")
            try:
                # Pause 10 seconds before the next room; Ctrl+C ends the
                # pause early instead of aborting the whole run.
                time.sleep(10)
            except KeyboardInterrupt:
                print("Moving to the next room...")

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/my_gsplat/datasets/Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self._pcd = depth_to_points(self._depth, self._K)

# NOTE: remove outliers
self._pcd, inlier_mask = remove_outliers(self._pcd, verbose=True)
self._pcd, inlier_mask = remove_outliers(self._pcd, verbose=False)
self._colors = (self._rgb / 255.0).reshape(-1, 3)[inlier_mask] # N,3

# self._colors = (self._rgb / 255.0).reshape(-1, 3) # N,3
Expand Down
71 changes: 8 additions & 63 deletions src/my_gsplat/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from dataclasses import dataclass

import torch
from nerfview import Viewer
Expand All @@ -16,62 +15,25 @@


@dataclass
class DatasetConfig:
data_dir: str = "./data/360_v2/garden"
data_factor: int = 4
result_dir: str | Path = "./results/Replica"
test_every: int = 8
patch_size: int | None = None
global_scale: float = 1.0

# make_dir
res_dir: Path | None = None
stats_dir: Path | None = None
render_dir: Path | None = None
ckpt_dir: Path | None = None

def make_dir(self):
# Where to dump results.
self.res_dir = Path(self.result_dir)
self.res_dir.mkdir(exist_ok=True, parents=True)

# Setup output directories.
self.ckpt_dir = self.res_dir / "ckpts"
self.ckpt_dir.mkdir(exist_ok=True)
self.stats_dir = self.res_dir / "stats"
self.stats_dir.mkdir(exist_ok=True)
self.render_dir = self.res_dir / "renders"
self.render_dir.mkdir(exist_ok=True)


@dataclass
class TrainingConfig:
batch_size: int = 1
class OptimizationConfig:
max_steps: int = 1000
eval_steps: list[int] = field(default_factory=lambda: [200, 30_000])
save_steps: list[int] = field(default_factory=lambda: [1000, 30_000])
steps_scaler: float = 1.0
refine_start_iter: int = 500
refine_stop_iter: int = 15_000
refine_every: int = 100
reset_every: int = 3000


@dataclass
class OptimizationConfig:
ssim_lambda: float = 0.5
depth_lambda: float = 0.8
normal_lambda: float = 0.0

ssim: StructuralSimilarityIndexMeasure = None
psnr: PeakSignalNoiseRatio = None
lpips: LearnedPerceptualImagePatchSimilarity = None

early_stop: bool = True
patience = 100
patience = 200
best_eR = float("inf")
best_eT = float("inf")
best_loss = float("inf")
best_depth_loss = float("inf")
best_silhouette_loss = float("inf")
best_pose: Tensor = torch.eye(4)

counter = 0

Expand All @@ -81,13 +43,6 @@ def init_loss(self):
self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(DEVICE)


@dataclass
class DepthLossConfig:
    """Depth-loss settings.

    NOTE(review): field meanings below are inferred from their names only;
    confirm against the trainer that reads them.
    """

    depth_loss: bool = False  # off by default; presumably toggles the depth term
    depth_lambda: float = 0.8  # presumably the weight of the depth term
    normal_lambda: float = 0.00  # presumably the weight of the normal term


@dataclass
class ViewerConfig:
disable_viewer: bool = True
Expand All @@ -109,22 +64,13 @@ def init_view(self, viewer_render_fn: Callable):

@dataclass
class Config(
TrainingConfig,
DatasetConfig,
OptimizationConfig,
DepthLossConfig,
ViewerConfig,
):
ckpt: str | None = None

def adjust_steps(self, factor: float = 1.0):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.max_steps = int(self.max_steps * factor)
self.refine_start_iter = int(self.refine_start_iter * factor)
self.refine_stop_iter = int(self.refine_stop_iter * factor)
self.reset_every = int(self.reset_every * factor)
self.refine_every = int(self.refine_every * factor)


@dataclass
Expand Down Expand Up @@ -181,11 +127,10 @@ class TrainData(TensorWrapper):
# for GS
points: Tensor # N,3 in camera
colors: Tensor # N,3
pixels: Tensor # H,W,3
pixels: Tensor # [1, H, W, 3]

depth: Tensor # H,w
depth: Tensor # [1, H, W, 1]
c2w: Tensor # 4,4
c2w_gt: Tensor
pca_factor: Tensor = torch.scalar_tensor(
1.0
) # for scale depth after rot normalized
55 changes: 10 additions & 45 deletions src/my_gsplat/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from ..geometry import compute_depth_gt, transform_points
from ..utils import as_intrinsics_matrix, load_camera_cfg, to_tensor
from .base import AlignData, TrainData
from .base import AlignData
from .Image import RGBDImage
from .normalize import normalize_2C, normalize_T
from .normalize import normalize_2C


class DataLoaderBase:
Expand Down Expand Up @@ -165,24 +165,15 @@ def __getitem__(self, index: int) -> AlignData:
ks = self.K.unsqueeze(0) # [1, 3, 3]
h, w = src.depth.shape

# # NOTE: normalize_points_spherical
# tar.points, _ = normalize_points_spherical(tar.points)
# src.points, sphere_factor = normalize_points_spherical(src.points)
# tar.pose = adjust_pose_spherical(tar.pose, _)
# src.pose = adjust_pose_spherical(src.pose, sphere_factor)

# NOTE: project depth
src.depth = (
compute_depth_gt(
src.points,
src.colors,
ks,
c2w=tar.pose.unsqueeze(0),
height=h,
width=w,
)
# / pca_factor
) # / sphere_factor
src.depth = compute_depth_gt(
src.points,
src.colors,
ks,
c2w=tar.pose.unsqueeze(0),
height=h,
width=w,
)
return AlignData(
pca_factor=pca_factor,
colors=tar.colors,
Expand All @@ -194,29 +185,3 @@ def __getitem__(self, index: int) -> AlignData:
src_c2w=src.pose,
tar_nums=tar.points.shape[0],
)


class Parser2(Replica):
    """Replica sequence that returns one ``TrainData`` per frame.

    Args:
        name: forwarded to ``Replica.__init__``.
        normalize: when true, ``self.normalize_T`` is computed from frame 0
            via ``normalize_T``; otherwise it is ``None``.
    """

    def __init__(self, name: str = "room0", normalize: bool = False):
        super().__init__(name=name)
        self.K = to_tensor(self.K, requires_grad=True)
        # normalize points and pose
        init_rgb_d: RGBDImage = super().__getitem__(0)
        # Frame 0 points are transformed with its own pose before the
        # normalization transform is derived from them.
        init_rgb_d.points = transform_points(init_rgb_d.pose, init_rgb_d.points)
        self.normalize_T = normalize_T(init_rgb_d) if normalize else None

    def __len__(self) -> int:
        return super().__len__()

    def __getitem__(self, index: int) -> TrainData:
        """Return the frame at ``index`` packed as ``TrainData``.

        Raises:
            IndexError: if ``index`` is at or past the end of the sequence.
        """
        # An explicit IndexError (not an assert) survives ``python -O`` and
        # lets ``for item in parser`` stop cleanly at the end of the sequence.
        if index >= len(self):
            raise IndexError(
                f"index {index} out of range for dataset of length {len(self)}"
            )
        tar = super().__getitem__(index)

        return TrainData(
            colors=tar.colors,
            pixels=tar.rgbs / 255.0,
            points=tar.points,
            depth=tar.depth,
            c2w=tar.pose,
            c2w_gt=tar.pose,
        )
Loading

0 comments on commit c28e706

Please sign in to comment.