From 7424edd9b187afe31f4cb2e9fdb530711daa5991 Mon Sep 17 00:00:00 2001 From: atticuszz <1831768457@qq.com> Date: Mon, 29 Jul 2024 12:25:40 +0800 Subject: [PATCH] feat: add plot rmse --- src/eval/logger.py | 80 +++++++++++++++++++++++++++++++++++++++++++--- src/eval/utils.py | 11 ++++++- src/plot_rmse.py | 6 ++++ 3 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 src/plot_rmse.py diff --git a/src/eval/logger.py b/src/eval/logger.py index 1b206ac..dab43c7 100644 --- a/src/eval/logger.py +++ b/src/eval/logger.py @@ -1,12 +1,15 @@ from datetime import datetime +from pathlib import Path from typing import Literal - +import pandas import torch from matplotlib import pyplot as plt +from pandas import DataFrame from torch import Tensor import wandb +from .utils import calculate_RMSE from ..my_gsplat.geometry import compute_silhouette_diff @@ -19,13 +22,16 @@ def __init__(self, run_name: str | None = None, config: dict = None): if run_name is None: run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") else: - run_name = run_name + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_name = run_name + self.entity = "supavision" + self.project = "ABGICP" wandb.init( - project="ABGICP", - entity="supavision", + project=self.project, + entity=self.entity, name=run_name, config=config, ) + self.api = wandb.Api() print(f"Run name: {run_name}:config: {config}") @@ -227,3 +233,69 @@ def plot_rgbd( wandb.log({fig_title: wandb.Image(fig)}, step=step) plt.close() + + def plot_bar( + self, + plot_name: str, + label_name: str, + value_name: str, + labels: list, + values: list, + ): + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=[label_name, value_name]) + wandb.log( + {plot_name: wandb.plot.bar(table, label_name, value_name, title=plot_name)} + ) + + def load_history(self, tags: str = "gsplatloc") -> dict[DataFrame]: + """ + https://docs.wandb.ai/ref/python/public-api/api#runs + Parameters + ---------- + tags: str + + Returns + ------- + histories: dict[str(sub_set),run.history] + """ + # filter tags runs + _runs = self.api.runs( + path="supavision/ABGICP", filters={"tags": tags}, order="config.+.sub_set" + ) + histories = {} + for _run in _runs: + run_path = Path(*_run.path).as_posix() + run = self.api.run(path=run_path) + histories[run.config["sub_set"]] = run.history(2000) + assert len(histories) == len(_runs) + return histories + + def plot_RMSE(self): + histories = self.load_history() + ates = [] + ares = [] + scenes = [] + # Read and calculate RMSEs + for scene, his in histories.items(): + scenes.append(scene) + eT = his["Translation Error"].to_numpy() + eR = his["Rotation Error"].to_numpy() + ates.append(calculate_RMSE(eT)) + ares.append(calculate_RMSE(eR)) + + self.plot_bar( + "ATE of Replica", + label_name="scenes", + value_name="ATEs", + values=ates, + labels=scenes, + ) + self.plot_bar( + "ARE of Replica", + label_name="scenes", + value_name="AREs", + values=ares, + labels=scenes, + ) diff --git a/src/eval/utils.py b/src/eval/utils.py index e1ceafe..bf5f6c1 100644 --- a/src/eval/utils.py +++ b/src/eval/utils.py @@ -1,5 +1,5 @@ import numpy as np -from numpy._typing import NDArray +from numpy.typing import NDArray def calculate_translation_error( @@ -99,3 +99,12 @@ def diff_pcd_COM(pcd_1: NDArray[np.float64], pcd_2: NDArray[np.float64]) -> floa com2 = np.mean(pcd_2, axis=0) distance = np.linalg.norm(com1 - com2) return distance + + +def calculate_RMSE(eT: NDArray) -> float: + """ + Returns + ------- + RMSE: float + """ + return np.sqrt(np.mean(np.square(eT))) diff --git a/src/plot_rmse.py b/src/plot_rmse.py new file mode 100644 index 0000000..ab6c8cb --- /dev/null +++ b/src/plot_rmse.py @@ -0,0 +1,6 @@ +from src.eval.logger import WandbLogger + +if __name__ == "__main__": + log = WandbLogger(run_name="plot_RMSE", config={"description": "plots"}) + + log.plot_RMSE()