From 047793f1d66f27924188ef5f254a7adaceedafdd Mon Sep 17 00:00:00 2001 From: Karan Sikka Date: Mon, 30 Dec 2024 23:00:30 -0500 Subject: [PATCH] Save changes --- cli/__init__.py | 0 cli/friendly.py | 67 +++++++ cli/main.py | 197 +++++++++++++++++++ cli/types.py | 35 ++++ lightning_pose/model.py | 255 ++++++++++++++++++++++++ lightning_pose/model_config.py | 37 ++++ lightning_pose/train.py | 294 ++++++++++------------------ lightning_pose/utils/predictions.py | 22 +-- lightning_pose/utils/scripts.py | 26 ++- setup.py | 9 +- 10 files changed, 728 insertions(+), 214 deletions(-) create mode 100644 cli/__init__.py create mode 100644 cli/friendly.py create mode 100644 cli/main.py create mode 100644 cli/types.py create mode 100644 lightning_pose/model.py create mode 100644 lightning_pose/model_config.py diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cli/friendly.py b/cli/friendly.py new file mode 100644 index 00000000..728a2c1d --- /dev/null +++ b/cli/friendly.py @@ -0,0 +1,67 @@ +import argparse +import shutil +import sys +from typing import List + + +class ArgumentParser(argparse.ArgumentParser): + def __init__(self, **kwargs): + super().__init__( + formatter_class=_HelpFormatter, + epilog="documentation: \n" + " https://lightning-pose.readthedocs.io/en/latest/source/user_guide/index.html", + **kwargs, + ) + self.is_sub_parser = False + + def print_help(self, with_welcome=True, **kwargs): + if with_welcome and not self.is_sub_parser: + print("Welcome to the lightning-pose CLI!\n") + super().print_help(**kwargs) + + def error(self, message): + red = "\033[91m" + end = "\033[0m" + sys.stderr.write(red + f"error:\n{message}\n\n" + end) + + width = shutil.get_terminal_size().columns + sys.stderr.write("-" * width + "\n") + self.print_help(with_welcome=False) + sys.exit(2) + + +class ArgumentSubParser(ArgumentParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_sub_parser = True + + +class _HelpFormatter(argparse.HelpFormatter): + """Modifications on help text formatting for easier readability.""" + + def _split_lines(self, text: str, width: int) -> List[str]: + """Modified to preserve newlines and long words.""" + # First split into paragraphs, then wrap each separately: + # https://docs.python.org/3/library/textwrap.html#textwrap.TextWrapper.replace_whitespace + paragraphs = text.splitlines() + import textwrap + + lines: List[str] = [] + for p in paragraphs: + p_lines = textwrap.wrap( + p, width, break_long_words=False, break_on_hyphens=False + ) + # An empty paragraph should result in a newline. + if not p_lines: + p_lines = [""] + lines.extend(p_lines) + return lines + + def _fill_text(self, text: str, width: int, indent: str) -> str: + return "\n".join( + indent + line for line in self._split_lines(text, width - len(indent)) + ) + + def _format_action(self, *args, **kwargs): + """Modified to add a newline after each argument, for better readability.""" + return super()._format_action(*args, **kwargs) + "\n" diff --git a/cli/main.py b/cli/main.py new file mode 100644 index 00000000..8faeb1e5 --- /dev/null +++ b/cli/main.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import argparse +import datetime +import os +import sys +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lightning_pose.model import Model + +from . import friendly, types + + +def _build_parser(): + parser = friendly.ArgumentParser() + subparsers = parser.add_subparsers( + dest="command", + required=True, + help="Litpose command to run.", + parser_class=friendly.ArgumentSubParser, + ) + + # Train command + train_parser = subparsers.add_parser( + "train", + description="Train a lightning-pose model using the specified configuration file.", + usage="litpose train \\\n" + " [--output_dir OUTPUT_DIR] \\\n" + " [--overrides KEY=VALUE...]" + "", + ) + train_parser.add_argument( + "config_file", + type=types.config_file, + help="path a config file.\n" + "Download and modify the config template from: \n" + "https://github.com/paninski-lab/lightning-pose/blob/main/scripts/configs/config_default.yaml", + ) + train_parser.add_argument( + "--output_dir", + type=types.model_dir, + help="explicitly specifies the output model directory.\n" + "If not specified, defaults to " + "./outputs/{YYYY-MM-DD}/{HH:MM:SS}/", + ) + train_parser.add_argument( + "--overrides", + nargs="*", + metavar="KEY=VALUE", + help="overrides attributes of the config file. Uses hydra syntax:\n" + "https://hydra.cc/docs/advanced/override_grammar/basic/", + ) + + # Add arguments specific to the 'train' command here + + # Predict command + predict_parser = subparsers.add_parser( + "predict", + description="Predicts keypoints on videos or images.\n" + "\n" + " Video predictions are saved to:\n" + " /\n" + " └── video_preds/\n" + " ├── .csv (predictions)\n" + " ├── _.csv (losses)\n" + " └── labeled_videos/\n" + " └── _labeled.mp4\n" + "\n" + " Image predictions are saved to:\n" + " /\n" + " └── image_preds/\n" + " └── /\n" + " ├── predictions.csv\n" + " ├── predictions_.csv (losses)\n" + " └── _labeled.png\n", + usage="litpose predict ... [OPTIONS]", + ) + predict_parser.add_argument( + "model_dir", type=types.existing_model_dir, help="path to a model directory" + ) + + predict_parser.add_argument( + "input_path", + type=Path, + nargs="+", + help="one or more paths. They can be video files, image files, CSV files, or directories.\n" + " directory: predicts over videos or images in the directory.\n" + " saves image outputs to `image_preds/`\n" + " video file: predicts on the video\n" + " image file: predicts on the image. saves outputs to `image_preds/`\n" + " CSV file: predicts on the images specified in the file.\n" + " uses the labels to compute pixel error.\n" + " saves outputs to `image_preds/`\n", + ) + + post_prediction_args = predict_parser.add_argument_group("post-prediction") + post_prediction_args.add_argument( + "--skip_viz", + action="store_true", + help="skip generating prediction-annotated images/videos", + ) + return parser + + +def main(): + parser = _build_parser() + + # If no commands provided, display the help message. + if len(sys.argv) == 1: + parser.print_help(sys.stderr) + sys.exit(1) + + args = parser.parse_args() + + if args.command == "train": + _train(args) + + elif args.command == "predict": + _predict(args) + + +def _train(args: argparse.Namespace): + import hydra + + if args.output_dir: + output_dir = args.output_dir + else: + now = datetime.datetime.now() + output_dir = ( + Path("outputs") / now.strftime("%Y-%m-%d") / now.strftime("%H:%M:%S") + ) + + print(f"Output directory: {output_dir.absolute()}") + if args.overrides: + print(f"Overrides: {args.overrides}") + + with hydra.initialize_config_dir( + version_base="1.1", config_dir=str(args.config_file.parent.absolute()) + ): + cfg = hydra.compose(config_name=args.config_file.stem, overrides=args.overrides) + + # Delay this import because it's slow. + from lightning_pose.train import train + + # TODO: Move some aspects of directory mgmt to the train function. + output_dir.mkdir(parents=True, exist_ok=True) + # Maintain legacy hydra chdir until downstream no longer depends on it. + os.chdir(output_dir) + train(cfg) + + +def _predict(args: argparse.Namespace): + # Delay this import because it's slow. + from lightning_pose.model import Model + + model_dir = Path(args.model_dir) + if not model_dir.is_dir(): + raise FileNotFoundError(f"Model directory not found: {model_dir.absolute()}") + + model = Model.from_dir(model_dir) + input_paths = [Path(p) for p in args.input_path] + + for p in input_paths: + _predict_multi_type(model, p, args.skip_viz) + + +def _predict_multi_type(model: Model, path: Path, skip_viz: bool): + if path.is_dir(): + image_files = [ + p for p in path.iterdir() if p.is_file() and p.suffix in [".png", ".jpg"] + ] + video_files = [p for p in path.iterdir() if p.is_file() and p.suffix == ".mp4"] + + if len(image_files) > 0: + raise NotImplementedError("Predicting on image dir.") + + for p in video_files: + _predict_multi_type(model, p, skip_viz) + elif path.suffix == ".mp4": + model.predict_on_video_file( + video_file=path, generate_labeled_video=(not skip_viz) + ) + elif path.suffix == ".csv": + model.predict_on_label_csv( + csv_file=path, + generate_labeled_images=False, # TODO: implement visualization + ) + elif path.suffix in [".png", ".jpg"]: + raise NotImplementedError("Not yet implemented: predicting on image files.") + else: + pass + + +if __name__ == "__main__": + main() diff --git a/cli/types.py b/cli/types.py new file mode 100644 index 00000000..a08e454d --- /dev/null +++ b/cli/types.py @@ -0,0 +1,35 @@ +import argparse +from pathlib import Path + + +def config_file(filepath): + """ + Custom argparse type for validating that a file exists and is a yaml file. + + Args: + filepath: The file path string. + + Returns: + A pathlib.Path object if the file is valid, otherwise raises an error. + """ + path = Path(filepath) + if not path.is_file(): + raise argparse.ArgumentTypeError(f"File not found: {filepath}") + if not path.suffix == ".yaml": + raise argparse.ArgumentTypeError(f"File must be a yaml file: {filepath}") + return path + + +def model_dir(filepath): + path = Path(filepath) + return path + + +def existing_model_dir(filepath): + path = model_dir(filepath) + if not path.is_dir(): + raise argparse.ArgumentTypeError( + f"Directory model_dir does not exist: {filepath}" + ) + + return path diff --git a/lightning_pose/model.py b/lightning_pose/model.py new file mode 100644 index 00000000..d1b54a05 --- /dev/null +++ b/lightning_pose/model.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import copy +from pathlib import Path +from typing import Optional, TypedDict + +import pandas as pd +from omegaconf import DictConfig, OmegaConf + +from lightning_pose.model_config import ModelConfig +from lightning_pose.models import ALLOWED_MODELS +from lightning_pose.utils.io import ckpt_path_from_base_path +from lightning_pose.utils.predictions import ( + export_predictions_and_labeled_video, + load_model_from_checkpoint, + predict_dataset, +) + +# Import as different name to avoid naming conflict with the kwarg `compute_metrics`. +from lightning_pose.utils.scripts import compute_metrics as compute_metrics_fn +from lightning_pose.utils.scripts import ( + get_data_module, + get_dataset, + get_imgaug_transform, +) + +__all__ = ["Model"] + + +class Model: + model_dir: Path + config: ModelConfig + model: Optional[ALLOWED_MODELS] = None + + @staticmethod + def from_dir(model_dir: str | Path): + model_dir = Path(model_dir) + config = ModelConfig.from_yaml_file(model_dir / "config.yaml") + return Model(model_dir, config) + + def __init__(self, model_dir: str | Path, config: ModelConfig): + self.model_dir = Path(model_dir) + self.config = config + + @property + def cfg(self): + return self.config.cfg + + def _load(self): + if self.model is None: + ckpt_file = ckpt_path_from_base_path( + base_path=str(self.model_dir), model_name=self.cfg.model.model_name + ) + self.model = load_model_from_checkpoint( + cfg=self.cfg, + ckpt_file=ckpt_file, + eval=True, + skip_data_module=True, + ) + + def image_preds_dir(self): + return self.model_dir / "image_preds" + + def video_preds_dir(self): + return self.model_dir / "video_preds" + + def labeled_videos_dir(self): + return self.model_dir / "video_preds" / "labeled_videos" + + UNSPECIFIED = "unspecified" + + class PredictionResult(TypedDict): + predictions: pd.DataFrame + metrics: pd.DataFrame + + def predict_on_label_csv( + self, + csv_file: str | Path, + data_dir: Optional[str | Path] = None, + compute_metrics: bool = True, + generate_labeled_images: bool = False, + output_dir: Optional[str | Path] = UNSPECIFIED, + ) -> PredictionResult: + """Predicts on a labeled dataset and computes error/loss metrics if applicable. + + Args: + csv_file (str | Path): Path to the CSV file of images, keypoint locations. + data_dir (str | Path, optional): Root path for relative paths in the CSV file. Defaults to the + parent directory of the CSV file. + compute_metrics (bool, optional): Whether to compute pixel error and loss metrics on + predictions. + generate_labeled_images (bool, optional): Whether to save labeled images. Defaults to False. + output_dir (str | Path, optional): The directory to save outputs to. + Defaults to `{model_dir}/image_preds/{csv_file_name}`. If set to None, outputs are not saved. + Returns: + PredictionResult: A PredictionResult object containing the predictions + and metrics. + """ + return self.predict_on_label_csv_internal( + csv_file=csv_file, + data_dir=data_dir, + compute_metrics=compute_metrics, + generate_labeled_images=generate_labeled_images, + output_dir=output_dir, + output_filename_stem="predictions", + add_train_val_test_set=False, + ) + + def predict_on_label_csv_internal( + self, + csv_file: str | Path, + data_dir: Optional[str | Path] = None, + compute_metrics: bool = True, + generate_labeled_images: bool = False, + output_dir: Optional[str | Path] = UNSPECIFIED, + output_filename_stem: str = "predictions", + add_train_val_test_set: bool = False, + ) -> PredictionResult: + """ + See predict_on_label_csv for the rest of the arguments. The following are the + arguments specific to the internal function. + Args: + output_filename_stem (str): The stem of the output filename. Defaults to 'predictions'. + Used to generate predictions_new for OOD, and predictions_{view_name} for multi-view, in the + model_dir. + add_train_val_test_set (bool): When predicting on training dataset, set to true to add the `set` + column to the prediction output. + """ + + self._load() + csv_file = Path(csv_file) + if data_dir is None: + data_dir = csv_file.parent + + if output_dir == self.__class__.UNSPECIFIED: + output_dir = self.image_preds_dir() / csv_file.name + + elif output_dir is None: + raise NotImplementedError("Currently we must save predictions") + + output_dir.mkdir(parents=True, exist_ok=True) + + if generate_labeled_images: + raise NotImplementedError() + + # Point predict_dataset to the csv_file and data_dir. + cfg_overrides = { + "data": { + "data_dir": str(data_dir), + "csv_file": str(csv_file), + } + } + + # Avoid annotating set=train/val/test for CSV file other than the training CSV file. + if not add_train_val_test_set: + cfg_overrides = {"train_prob": 1, "val_prob": 0, "train_frames": 1} + + cfg_pred = OmegaConf.merge(self.cfg, cfg_overrides) + + data_module_pred = _build_datamodule_pred(cfg_pred) + + preds_file_path = output_dir / (output_filename_stem + ".csv") + preds_file = str(preds_file_path) + + df = predict_dataset( + cfg_pred, data_module_pred, model=self.model, preds_file=preds_file + ) + + if compute_metrics: + # For multiview, Compute metrics requires preds_file be a list in order of sorted view_names. + compute_metrics_on_preds_file = preds_file + if self.config.is_multi_view(): + compute_metrics_on_preds_file = [] + for view_name in sorted(self.config.cfg.data.view_names): + multiview_preds_file_path = preds_file_path.with_name( + preds_file_path.stem + f"_{view_name}.csv" + ) + compute_metrics_on_preds_file.append(str(multiview_preds_file_path)) + + import os + + print(os.getcwd()) + compute_metrics_fn( + cfg=cfg_pred, + preds_file=compute_metrics_on_preds_file, + data_module=data_module_pred, + ) + + # TODO: Generate detector outputs. + + return self.PredictionResult(predictions=df) + + def predict_on_video_file( + self, + video_file: str | Path, + output_dir: Optional[str | Path] = UNSPECIFIED, + compute_metrics: bool = True, + generate_labeled_video: bool = False, + ) -> PredictionResult: + self._load() + video_file = Path(video_file) + + if output_dir == self.__class__.UNSPECIFIED: + output_dir = self.video_preds_dir() + + elif output_dir is None: + raise NotImplementedError("Currently we must save predictions") + + output_dir.mkdir(parents=True, exist_ok=True) + + prediction_csv_file = output_dir / f"{video_file.stem}.csv" + + labeled_mp4_file = None + if generate_labeled_video: + labeled_mp4_file = str( + self.labeled_videos_dir() / f"{video_file.stem}_labeled.mp4" + ) + + if self.config.cfg.eval.get("predict_vids_after_training_save_heatmaps", False): + raise NotImplementedError( + "Implement this after cleaning up _predict_frames: " + "Set a flag on the model to return heatmaps. " + "Use trainer.predict instead of side-stepping it." + ) + df = export_predictions_and_labeled_video( + video_file=str(video_file), + cfg=self.config.cfg, + prediction_csv_file=str(prediction_csv_file), + labeled_mp4_file=labeled_mp4_file, + model=self.model, + ) + + # FIXME: This is only used for computing PCA metrics. + data_module = _build_datamodule_pred(self.cfg) + if compute_metrics: + compute_metrics_fn(self.cfg, str(prediction_csv_file), data_module) + + return self.PredictionResult(predictions=df) + + +def _build_datamodule_pred(cfg: DictConfig): + cfg_pred = copy.deepcopy(cfg) + cfg_pred.training.imgaug = "default" + imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred) + dataset_pred = get_dataset( + cfg=cfg_pred, + data_dir=cfg_pred.data.data_dir, + imgaug_transform=imgaug_transform_pred, + ) + data_module_pred = get_data_module( + cfg=cfg_pred, dataset=dataset_pred, video_dir=cfg_pred.data.video_dir + ) + data_module_pred.setup() + + return data_module_pred diff --git a/lightning_pose/model_config.py b/lightning_pose/model_config.py new file mode 100644 index 00000000..b4251b18 --- /dev/null +++ b/lightning_pose/model_config.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from omegaconf import DictConfig, OmegaConf + +__all__ = ["ModelConfig"] + +from lightning_pose.utils.io import check_video_paths, return_absolute_path + + +class ModelConfig: + + @staticmethod + def from_yaml_file(filepath): + return ModelConfig(OmegaConf.load(filepath)) + + def __init__(self, cfg: DictConfig): + self.cfg = cfg + + def is_single_view(self): + return not self.is_multi_view() + + def is_multi_view(self): + if self.cfg.data.get("view_names") is None: + return False + if len(self.cfg.data.view_names) == 1: + raise ValueError( + "view_names should not be specified if there is only one view." + ) + return True + + ## Eval ## + + def test_video_files(self) -> list[Path]: + files = check_video_paths( + return_absolute_path(self.cfg.eval.test_videos_directory) + ) + return [Path(f) for f in files] diff --git a/lightning_pose/train.py b/lightning_pose/train.py index e641b037..8b25996e 100644 --- a/lightning_pose/train.py +++ b/lightning_pose/train.py @@ -1,5 +1,5 @@ """Example model training function.""" -import copy + import os import random import shutil @@ -13,22 +13,11 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from typeguard import typechecked +from lightning_pose.model import Model from lightning_pose.utils import pretty_print_cfg, pretty_print_str -from lightning_pose.utils.cropzoom import generate_cropped_labeled_frames, generate_cropped_video -from lightning_pose.utils.io import ( - check_video_paths, - ckpt_path_from_base_path, - return_absolute_data_paths, - return_absolute_path, -) -from lightning_pose.utils.predictions import ( - export_predictions_and_labeled_video, - load_model_from_checkpoint, - predict_dataset, -) +from lightning_pose.utils.io import return_absolute_data_paths from lightning_pose.utils.scripts import ( calculate_train_batches, - compute_metrics, get_callbacks, get_data_module, get_dataset, @@ -37,13 +26,112 @@ get_model, ) -# to ignore imports for sphix-autoapidoc +# to ignore imports for sphinx-autoapidoc __all__ = ["train"] @typechecked def train(cfg: DictConfig) -> None: + model = _train(cfg) + # Comment out the above, and uncomment the below to skip + # training and go straight to post-training analysis: + # import os + # from lightning_pose.model import Model + # model = Model.from_dir(os.getcwd()) + + _evaluate_on_training_dataset(model) + _evaluate_on_ood_dataset(model) + _predict_test_videos(model) + + +def _absolute_csv_file(csv_file, data_dir): + csv_file = Path(csv_file) + if not csv_file.is_absolute(): + return Path(data_dir) / csv_file + return csv_file + + +def _evaluate_on_training_dataset(model: Model): + pretty_print_str("Predicting train/val/test images...") + + if model.config.is_single_view(): + csv_file = _absolute_csv_file( + model.config.cfg.data.csv_file, model.config.cfg.data.data_dir + ) + csv_files = [csv_file] + output_filename_stems = ["predictions"] + else: + csv_files = [] + output_filename_stems = [] + for csv_file, view_name in zip( + model.config.cfg.data.csv_file, model.config.cfg.data.view_names + ): + csv_files.append( + _absolute_csv_file(csv_file, model.config.cfg.data.data_dir) + ) + output_filename_stems.append(f"predictions_{view_name}") + + for csv_file, output_filename_stem in zip(csv_files, output_filename_stems): + model.predict_on_label_csv_internal( + csv_file=csv_file, + data_dir=model.config.cfg.data.data_dir, + # TODO annotate with train/val/test split metadata. + compute_metrics=True, + generate_labeled_images=False, + output_dir=model.model_dir, + output_filename_stem=output_filename_stem, + add_train_val_test_set=True, + ) + +def _evaluate_on_ood_dataset(model: Model): + if model.config.is_single_view(): + csv_file = _absolute_csv_file( + model.config.cfg.data.csv_file, model.config.cfg.data.data_dir + ) + ood_csv_file = csv_file.with_stem(csv_file.stem + "_new") + ood_csv_files = [ood_csv_file] + output_filename_stems = ["predictions_new"] + else: + ood_csv_files = [] + output_filename_stems = [] + for csv_file, view_name in zip( + model.config.cfg.data.csv_file, model.config.cfg.data.view_names + ): + csv_file = _absolute_csv_file(csv_file, model.config.cfg.data.data_dir) + ood_csv_file = csv_file.with_stem(csv_file.stem + "_new") + ood_csv_files.append(ood_csv_file) + output_filename_stems.append(f"predictions_new_{view_name}") + + if ood_csv_files[0].is_file(): + pretty_print_str("Predicting OOD images...") + + for ood_csv_file, output_filename_stem in zip( + ood_csv_files, output_filename_stems + ): + model.predict_on_label_csv_internal( + csv_file=ood_csv_file, + data_dir=model.config.cfg.data.data_dir, + compute_metrics=True, + generate_labeled_images=False, + output_dir=model.model_dir, + output_filename_stem=output_filename_stem, + ) + + +def _predict_test_videos(model: Model): + if model.config.cfg.eval.predict_vids_after_training: + pretty_print_str(f"Predicting videos in cfg.eval.test_videos_directory...") + for video_file in model.config.test_video_files(): + pretty_print_str(f"Predicting video: {video_file}...") + + model.predict_on_video_file( + Path(video_file), + generate_labeled_video=model.config.cfg.eval.save_vids_after_training, + ) + + +def _train(cfg: DictConfig) -> Model: # reset all seeds seed = 0 os.environ["PYTHONHASHSEED"] = str(seed) @@ -55,6 +143,7 @@ def train(cfg: DictConfig) -> None: # record lightning-pose version from lightning_pose import __version__ as lightning_pose_version + with open_dict(cfg): cfg.model.lightning_pose_version = lightning_pose_version @@ -127,7 +216,7 @@ def train(cfg: DictConfig) -> None: cfg, early_stopping=cfg.training.get("early_stopping", False), lr_monitor=True, - ckpt_every_n_epochs=cfg.training.get("ckpt_every_n_epochs", None) + ckpt_every_n_epochs=cfg.training.get("ckpt_every_n_epochs", None), ) # calculate number of batches for both labeled and unlabeled data per epoch @@ -169,175 +258,6 @@ def train(cfg: DictConfig) -> None: if not trainer.is_global_zero: sys.exit(0) - # ---------------------------------------------------------------------------------- - # Post-training analysis - # ---------------------------------------------------------------------------------- - # get best ckpt - best_ckpt = ckpt_path_from_base_path( - base_path=hydra_output_directory, model_name=cfg.model.model_name - ) - print(f"Best checkpoint: {os.path.basename(best_ckpt)}") - # check if best_ckpt is a file - if not os.path.isfile(best_ckpt): - raise FileNotFoundError("Cannot find checkpoint. Have you trained for too few epochs?") - - # make unaugmented data_loader if necessary - if cfg.training.imgaug != "default": - cfg_pred = copy.deepcopy(cfg) - cfg_pred.training.imgaug = "default" - imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred) - dataset_pred = get_dataset( - cfg=cfg_pred, data_dir=data_dir, imgaug_transform=imgaug_transform_pred - ) - data_module_pred = get_data_module(cfg=cfg_pred, dataset=dataset_pred, video_dir=video_dir) - data_module_pred.setup() - else: - data_module_pred = data_module - - model = load_model_from_checkpoint( - cfg=cfg, - ckpt_file=best_ckpt, - eval=True, - data_module=data_module_pred, - ) + from lightning_pose.model import Model - # ---------------------------------------------------------------------------------- - # predict on all labeled frames (train/val/test) - # ---------------------------------------------------------------------------------- - # Rebuild trainer with devices=1 for prediction. Training flags not needed. - trainer = pl.Trainer(accelerator="gpu", devices=1) - pretty_print_str("Predicting train/val/test images...") - # compute and save frame-wise predictions - preds_file = os.path.join(hydra_output_directory, "predictions.csv") - predict_dataset( - cfg=cfg, - trainer=trainer, - model=model, - data_module=data_module_pred, - preds_file=preds_file, - ) - # compute and save various metrics - # for multiview, predict_dataset outputs one pred file per view. - multiview_pred_files = [ - str(Path(hydra_output_directory) / p) - for p in Path(hydra_output_directory).glob("predictions_*.csv") - ] - if len(multiview_pred_files) > 0: - preds_file = multiview_pred_files - compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred) - - is_detector = ( - cfg.get("detector") is not None and cfg.detector.get("crop_ratio") is not None - ) - if is_detector: - generate_cropped_labeled_frames( - root_directory=Path(data_dir), - output_directory=Path(hydra_output_directory), - detector_cfg=cfg.detector, - ) - - # ---------------------------------------------------------------------------------- - # predict folder of videos - # ---------------------------------------------------------------------------------- - if cfg.eval.predict_vids_after_training or is_detector: - pretty_print_str("Predicting videos...") - if cfg.eval.test_videos_directory is None: - filenames = [] - else: - filenames = check_video_paths(return_absolute_path(cfg.eval.test_videos_directory)) - vidstr = "video" if (len(filenames) == 1) else "videos" - pretty_print_str( - f"Found {len(filenames)} {vidstr} to predict (in cfg.eval.test_videos_directory)" - ) - - for video_file in filenames: - assert os.path.isfile(video_file) - - pretty_print_str(f"Predicting video: {video_file}...") - # get save name for prediction csv file - video_pred_dir = os.path.join(hydra_output_directory, "video_preds") - video_pred_name = os.path.splitext(os.path.basename(video_file))[0] - prediction_csv_file = os.path.join(video_pred_dir, video_pred_name + ".csv") - # get save name labeled video csv - if cfg.eval.save_vids_after_training: - labeled_vid_dir = os.path.join(video_pred_dir, "labeled_videos") - labeled_mp4_file = os.path.join(labeled_vid_dir, video_pred_name + "_labeled.mp4") - else: - labeled_mp4_file = None - # predict on video - export_predictions_and_labeled_video( - video_file=video_file, - cfg=cfg, - prediction_csv_file=prediction_csv_file, - labeled_mp4_file=labeled_mp4_file, - trainer=trainer, - model=model, - data_module=data_module_pred, - save_heatmaps=cfg.eval.get( - "predict_vids_after_training_save_heatmaps", False - ), - ) - - compute_metrics( - cfg=cfg, - preds_file=prediction_csv_file, - data_module=data_module_pred, - ) - - if is_detector: - generate_cropped_video( - video_path=Path(video_file), - detector_model_dir=Path(hydra_output_directory), - detector_cfg=cfg.detector, - ) - - # ---------------------------------------------------------------------------------- - # predict on OOD frames - # ---------------------------------------------------------------------------------- - # update config file to point to OOD data - if isinstance(cfg.data.csv_file, list) or isinstance(cfg.data.csv_file, ListConfig): - csv_file_ood = [] - for csv_file in cfg.data.csv_file: - csv_file_ood.append( - os.path.join(cfg.data.data_dir, csv_file).replace(".csv", "_new.csv")) - else: - csv_file_ood = os.path.join( - cfg.data.data_dir, cfg.data.csv_file).replace(".csv", "_new.csv") - - if (isinstance(csv_file_ood, str) and os.path.exists(csv_file_ood)) \ - or (isinstance(csv_file_ood, list) and os.path.exists(csv_file_ood[0])): - cfg_ood = cfg.copy() - cfg_ood.data.csv_file = csv_file_ood - cfg_ood.training.imgaug = "default" - cfg_ood.training.train_prob = 1 - cfg_ood.training.val_prob = 0 - cfg_ood.training.train_frames = 1 - # build dataset/datamodule - imgaug_transform_ood = get_imgaug_transform(cfg=cfg_ood) - dataset_ood = get_dataset( - cfg=cfg_ood, data_dir=data_dir, imgaug_transform=imgaug_transform_ood - ) - data_module_ood = get_data_module(cfg=cfg_ood, dataset=dataset_ood, video_dir=video_dir) - data_module_ood.setup() - pretty_print_str("Predicting OOD images...") - # compute and save frame-wise predictions - preds_file_ood = os.path.join(hydra_output_directory, "predictions_new.csv") - predict_dataset( - cfg=cfg_ood, - trainer=trainer, - model=model, - data_module=data_module_ood, - preds_file=preds_file_ood, - ) - # compute and save various metrics - try: - # take care of multiview case, where multiple csv files have been saved - preds_files = [ - os.path.join(hydra_output_directory, path) for path in - os.listdir(hydra_output_directory) if path.startswith("predictions_new") - ] - if len(preds_files) > 1: - preds_file_ood = preds_files - compute_metrics(cfg=cfg_ood, preds_file=preds_file_ood, data_module=data_module_ood) - except Exception as e: - print(f"Error computing metrics\n{e}") + return Model.from_dir(hydra_output_directory) diff --git a/lightning_pose/utils/predictions.py b/lightning_pose/utils/predictions.py index 13cf5dd7..ffca4cb6 100644 --- a/lightning_pose/utils/predictions.py +++ b/lightning_pose/utils/predictions.py @@ -73,11 +73,9 @@ def __init__( if data_module is None: if video_file is None: raise ValueError("must pass data_module to constructor if predicting on a dataset") - if cfg.data.get("keypoint_names", None) is None \ - and cfg.data.get("keypoints", None) is None: - raise ValueError( - "must include `keypoint_names` or `keypoints` field in cfg.data if not " - "passing data_module as an argument to PredictionHandler") + if cfg.data.get("keypoint_names", None) is None: + raise ValueError( + "must include `keypoint_names` field in cfg.data") self.cfg = cfg self.data_module = data_module @@ -95,14 +93,7 @@ def frame_count(self) -> int: @property def keypoint_names(self): - if self.cfg.data.get("keypoint_names", None) is not None: - if isinstance(self.cfg.data.get("keypoint_names"), DictConfig): - return dict(self.cfg.data.get("keypoint_names")) - return list(self.cfg.data.keypoint_names) - elif self.cfg.data.get("keypoints", None) is not None: - return list(self.cfg.data.keypoints) - else: - return self.data_module.dataset.keypoint_names + return list(self.cfg.data.keypoint_names) @property def do_context(self): @@ -310,7 +301,7 @@ def predict_dataset( Args: cfg: hydra config data_module: data module that contains dataloaders for train, val, test splits - preds_file: absolute filename for the predictions .csv file + preds_file: path for the predictions .csv file ckpt_file: absolute path to the checkpoint of your trained model; requires .ckpt suffix trainer: pl.Trainer object model: Lightning Module @@ -887,7 +878,7 @@ def export_predictions_and_labeled_video( data_module: Optional[Union[BaseDataModule, UnlabeledDataModule]] = None, labeled_mp4_file: Optional[str] = None, save_heatmaps: Optional[bool] = False, -) -> None: +) -> pd.DataFrame: """Export predictions csv and a labeled video for a single video file.""" if ckpt_file is None and model is None: @@ -923,3 +914,4 @@ def export_predictions_and_labeled_video( output_video_path=labeled_mp4_file, colormap=cfg.eval.get("colormap", "cool") ) + return preds_df diff --git a/lightning_pose/utils/scripts.py b/lightning_pose/utils/scripts.py index dab91800..e81ac909 100644 --- a/lightning_pose/utils/scripts.py +++ b/lightning_pose/utils/scripts.py @@ -3,6 +3,7 @@ import os import warnings from collections import OrderedDict +from pathlib import Path from typing import Dict, List, Optional, Union import imgaug.augmenters as iaa @@ -553,21 +554,29 @@ def compute_metrics( preds_file: Union[str, List[str]], data_module: Optional[Union[BaseDataModule, UnlabeledDataModule]] = None, ) -> None: - """Compute various metrics on predictions csv file, potentially for multiple views.""" + """Compute various metrics on predictions csv file, potentially for multiple views. + Saves metrics to files next to predictions file, in the convention of: + {prediction_file_stem}_{metric_name}.csv + + Args: + cfg: the config used to determine whether single or multiview and which metrics + to compute + preds_file: Path to model predictions used to compute metrics. + For multiview, a list of paths. + + """ if ( cfg.data.get("view_names", None) and len(cfg.data.view_names) > 1 and isinstance(preds_file, list) ): - preds_file = sorted(preds_file) for view_name, csv_file, preds_file_ in zip( sorted(cfg.data.view_names), sorted(cfg.data.csv_file), - preds_file + sorted(preds_file) ): assert view_name in preds_file_ labels_file = return_absolute_path(os.path.join(cfg.data.data_dir, csv_file)) - # preds_file_ = preds_file.replace(".csv", f"_{view_name}.csv") compute_metrics_single( cfg=cfg, labels_file=labels_file, @@ -643,6 +652,7 @@ def compute_metrics_single( ): metrics_to_compute += ["pca_multiview"] + preds_file_path = Path(preds_file) # compute metrics; csv files will be saved to the same directory the prdictions are stored in if "pixel_error" in metrics_to_compute: keypoints_true = labels_df.to_numpy().reshape(labels_df.shape[0], -1, 2) @@ -651,7 +661,7 @@ def compute_metrics_single( # add train/val/test split if set is not None: error_df["set"] = set - save_file = preds_file.replace(".csv", "_pixel_error.csv") + save_file = preds_file_path.with_name(preds_file_path.stem + "_pixel_error.csv") error_df.to_csv(save_file) if "temporal" in metrics_to_compute: @@ -662,7 +672,7 @@ def compute_metrics_single( # add train/val/test split if set is not None: temporal_norm_df["set"] = set - save_file = preds_file.replace(".csv", "_temporal_norm.csv") + save_file = preds_file_path.with_name(preds_file_path.stem + "_temporal_norm.csv") temporal_norm_df.to_csv(save_file) if "pca_singleview" in metrics_to_compute: @@ -684,7 +694,7 @@ def compute_metrics_single( # add train/val/test split if set is not None: pcasv_df["set"] = set - save_file = preds_file.replace(".csv", "_pca_singleview_error.csv") + save_file = preds_file_path.with_name(preds_file_path.stem + "_pca_singleview_error.csv") pcasv_df.to_csv(save_file) if "pca_multiview" in metrics_to_compute: @@ -705,5 +715,5 @@ def compute_metrics_single( # add train/val/test split if set is not None: pcamv_df["set"] = set - save_file = preds_file.replace(".csv", "_pca_multiview_error.csv") + save_file = preds_file_path.with_name(preds_file_path.stem + "_pca_multiview_error.csv") pcamv_df.to_csv(save_file) diff --git a/setup.py b/setup.py index f62678e2..b71bb4c8 100644 --- a/setup.py +++ b/setup.py @@ -105,7 +105,7 @@ def get_cuda_version(): setup( name="lightning-pose", - packages=find_packages() + ["mirror_mouse_example"], # include data for wheel packaging + packages=find_packages(), version=get_version(Path("lightning_pose").joinpath("__init__.py")), description="Semi-supervised pose estimation using pytorch lightning", long_description=long_description, @@ -116,9 +116,10 @@ def get_cuda_version(): author_email="danbider@gmail.com", url="https://github.com/danbider/lightning-pose", keywords=["machine learning", "deep learning", "computer_vision"], - package_dir={ - "lightning_pose": "lightning_pose", - "mirror_mouse_example": "data/mirror-mouse-example", # remap 'data/mirror-mouse-example' + entry_points={ + 'console_scripts': [ + 'litpose = cli.main:main', + ], }, include_package_data=True, # required to get the non-.py data files in the wheel )