Skip to content

Commit

Permalink
feat: add plot rmse
Browse files Browse the repository at this point in the history
  • Loading branch information
AtticusZeller committed Jul 29, 2024
1 parent 0eef5e1 commit 7424edd
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
80 changes: 76 additions & 4 deletions src/eval/logger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from datetime import datetime
from pathlib import Path
from typing import Literal

import pandas
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from torch import Tensor

import wandb

from .utils import calculate_RMSE
from ..my_gsplat.geometry import compute_silhouette_diff


Expand All @@ -19,13 +22,16 @@ def __init__(self, run_name: str | None = None, config: dict = None):
if run_name is None:
run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
else:
run_name = run_name + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = run_name
self.entity = "supavision"
self.project = "ABGICP"
wandb.init(
project="ABGICP",
entity="supavision",
project=self.project,
entity=self.entity,
name=run_name,
config=config,
)
self.api = wandb.Api()

print(f"Run name: {run_name}:config: {config}")

Expand Down Expand Up @@ -227,3 +233,69 @@ def plot_rgbd(
wandb.log({fig_title: wandb.Image(fig)}, step=step)

plt.close()

def plot_bar(
self,
plot_name: str,
label_name: str,
value_name: str,
labels: list,
values: list,
):

data = [[label, val] for (label, val) in zip(labels, values)]
table = wandb.Table(data=data, columns=[label_name, value_name])
wandb.log(
{plot_name: wandb.plot.bar(table, label_name, value_name, title=plot_name)}
)

def load_history(self, tags: str = "gsplatloc") -> dict[DataFrame]:
"""
https://docs.wandb.ai/ref/python/public-api/api#runs
Parameters
----------
tags: str
Returns
-------
histories: dict[str(sub_set),run.history]
"""
# filter tags runs
_runs = self.api.runs(
path="supavision/ABGICP", filters={"tags": tags}, order="config.+.sub_set"
)
histories = {}
for _run in _runs:
run_path = Path(*_run.path).as_posix()
run = self.api.run(path=run_path)
histories[run.config["sub_set"]] = run.history(2000)
assert len(histories) == len(_runs)
return histories

def plot_RMSE(self):
histories = self.load_history()
ates = []
ares = []
scenes = []
# Read and calculate RMSEs
for scene, his in histories.items():
scenes.append(scene)
eT = his["Translation Error"].to_numpy()
eR = his["Rotation Error"].to_numpy()
ates.append(calculate_RMSE(eT))
ares.append(calculate_RMSE(eR))

self.plot_bar(
"ATE of Replica",
label_name="scenes",
value_name="ATEs",
values=ates,
labels=scenes,
)
self.plot_bar(
"ARE of Replica",
label_name="scenes",
value_name="AREs",
values=ares,
labels=scenes,
)
11 changes: 10 additions & 1 deletion src/eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy._typing import NDArray
from numpy.typing import NDArray


def calculate_translation_error(
Expand Down Expand Up @@ -99,3 +99,12 @@ def diff_pcd_COM(pcd_1: NDArray[np.float64], pcd_2: NDArray[np.float64]) -> floa
com2 = np.mean(pcd_2, axis=0)
distance = np.linalg.norm(com1 - com2)
return distance


def calculate_RMSE(eT: NDArray) -> float:
"""
Returns
-------
RMSE: float
"""
return np.sqrt(np.mean(np.square(eT)))
6 changes: 6 additions & 0 deletions src/plot_rmse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from src.eval.logger import WandbLogger

if __name__ == "__main__":
log = WandbLogger(run_name="plot_RMSE", config={"description": "plots"})

log.plot_RMSE()

0 comments on commit 7424edd

Please sign in to comment.