Skip to content

Commit

Permalink
Save changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ksikka committed Jan 15, 2025
1 parent 2b99f81 commit 047793f
Show file tree
Hide file tree
Showing 10 changed files with 728 additions and 214 deletions.
Empty file added cli/__init__.py
Empty file.
67 changes: 67 additions & 0 deletions cli/friendly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import argparse
import shutil
import sys
from typing import List


class ArgumentParser(argparse.ArgumentParser):
def __init__(self, **kwargs):
super().__init__(
formatter_class=_HelpFormatter,
epilog="documentation: \n"
" https://lightning-pose.readthedocs.io/en/latest/source/user_guide/index.html",
**kwargs,
)
self.is_sub_parser = False

def print_help(self, with_welcome=True, **kwargs):
if with_welcome and not self.is_sub_parser:
print("Welcome to the lightning-pose CLI!\n")
super().print_help(**kwargs)

def error(self, message):
red = "\033[91m"
end = "\033[0m"
sys.stderr.write(red + f"error:\n{message}\n\n" + end)

width = shutil.get_terminal_size().columns
sys.stderr.write("-" * width + "\n")
self.print_help(with_welcome=False)
sys.exit(2)


class ArgumentSubParser(ArgumentParser):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_sub_parser = True


class _HelpFormatter(argparse.HelpFormatter):
"""Modifications on help text formatting for easier readability."""

def _split_lines(self, text: str, width: int) -> List[str]:
"""Modified to preserve newlines and long words."""
# First split into paragraphs, then wrap each separately:
# https://docs.python.org/3/library/textwrap.html#textwrap.TextWrapper.replace_whitespace
paragraphs = text.splitlines()
import textwrap

lines: List[str] = []
for p in paragraphs:
p_lines = textwrap.wrap(
p, width, break_long_words=False, break_on_hyphens=False
)
# An empty paragraph should result in a newline.
if not p_lines:
p_lines = [""]
lines.extend(p_lines)
return lines

def _fill_text(self, text: str, width: int, indent: str) -> str:
return "\n".join(
indent + line for line in self._split_lines(text, width - len(indent))
)

def _format_action(self, *args, **kwargs):
"""Modified to add a newline after each argument, for better readability."""
return super()._format_action(*args, **kwargs) + "\n"
197 changes: 197 additions & 0 deletions cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from __future__ import annotations

import argparse
import datetime
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from lightning_pose.model import Model

from . import friendly, types


def _build_parser():
parser = friendly.ArgumentParser()
subparsers = parser.add_subparsers(
dest="command",
required=True,
help="Litpose command to run.",
parser_class=friendly.ArgumentSubParser,
)

# Train command
train_parser = subparsers.add_parser(
"train",
description="Train a lightning-pose model using the specified configuration file.",
usage="litpose train <config_file> \\\n"
" [--output_dir OUTPUT_DIR] \\\n"
" [--overrides KEY=VALUE...]"
"",
)
train_parser.add_argument(
"config_file",
type=types.config_file,
help="path a config file.\n"
"Download and modify the config template from: \n"
"https://github.com/paninski-lab/lightning-pose/blob/main/scripts/configs/config_default.yaml",
)
train_parser.add_argument(
"--output_dir",
type=types.model_dir,
help="explicitly specifies the output model directory.\n"
"If not specified, defaults to "
"./outputs/{YYYY-MM-DD}/{HH:MM:SS}/",
)
train_parser.add_argument(
"--overrides",
nargs="*",
metavar="KEY=VALUE",
help="overrides attributes of the config file. Uses hydra syntax:\n"
"https://hydra.cc/docs/advanced/override_grammar/basic/",
)

# Add arguments specific to the 'train' command here

# Predict command
predict_parser = subparsers.add_parser(
"predict",
description="Predicts keypoints on videos or images.\n"
"\n"
" Video predictions are saved to:\n"
" <model_dir>/\n"
" └── video_preds/\n"
" ├── <video_filename>.csv (predictions)\n"
" ├── <video_filename>_<metric>.csv (losses)\n"
" └── labeled_videos/\n"
" └── <video_filename>_labeled.mp4\n"
"\n"
" Image predictions are saved to:\n"
" <model_dir>/\n"
" └── image_preds/\n"
" └── <image_dirname | csv_filename | timestamp>/\n"
" ├── predictions.csv\n"
" ├── predictions_<metric>.csv (losses)\n"
" └── <image_filename>_labeled.png\n",
usage="litpose predict <model_dir> <input_path:video|image|dir|csv>... [OPTIONS]",
)
predict_parser.add_argument(
"model_dir", type=types.existing_model_dir, help="path to a model directory"
)

predict_parser.add_argument(
"input_path",
type=Path,
nargs="+",
help="one or more paths. They can be video files, image files, CSV files, or directories.\n"
" directory: predicts over videos or images in the directory.\n"
" saves image outputs to `image_preds/<directory_name>`\n"
" video file: predicts on the video\n"
" image file: predicts on the image. saves outputs to `image_preds/<timestamp>`\n"
" CSV file: predicts on the images specified in the file.\n"
" uses the labels to compute pixel error.\n"
" saves outputs to `image_preds/<csv_file_name>`\n",
)

post_prediction_args = predict_parser.add_argument_group("post-prediction")
post_prediction_args.add_argument(
"--skip_viz",
action="store_true",
help="skip generating prediction-annotated images/videos",
)
return parser


def main():
parser = _build_parser()

# If no commands provided, display the help message.
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)

args = parser.parse_args()

if args.command == "train":
_train(args)

elif args.command == "predict":
_predict(args)


def _train(args: argparse.Namespace):
import hydra

if args.output_dir:
output_dir = args.output_dir
else:
now = datetime.datetime.now()
output_dir = (
Path("outputs") / now.strftime("%Y-%m-%d") / now.strftime("%H:%M:%S")
)

print(f"Output directory: {output_dir.absolute()}")
if args.overrides:
print(f"Overrides: {args.overrides}")

with hydra.initialize_config_dir(
version_base="1.1", config_dir=str(args.config_file.parent.absolute())
):
cfg = hydra.compose(config_name=args.config_file.stem, overrides=args.overrides)

# Delay this import because it's slow.
from lightning_pose.train import train

# TODO: Move some aspects of directory mgmt to the train function.
output_dir.mkdir(parents=True, exist_ok=True)
# Maintain legacy hydra chdir until downstream no longer depends on it.
os.chdir(output_dir)
train(cfg)


def _predict(args: argparse.Namespace):
# Delay this import because it's slow.
from lightning_pose.model import Model

model_dir = Path(args.model_dir)
if not model_dir.is_dir():
raise FileNotFoundError(f"Model directory not found: {model_dir.absolute()}")

model = Model.from_dir(model_dir)
input_paths = [Path(p) for p in args.input_path]

for p in input_paths:
_predict_multi_type(model, p, args.skip_viz)


def _predict_multi_type(model: Model, path: Path, skip_viz: bool):
if path.is_dir():
image_files = [
p for p in path.iterdir() if p.is_file() and p.suffix in [".png", ".jpg"]
]
video_files = [p for p in path.iterdir() if p.is_file() and p.suffix == ".mp4"]

if len(image_files) > 0:
raise NotImplementedError("Predicting on image dir.")

for p in video_files:
_predict_multi_type(model, p, skip_viz)
elif path.suffix == ".mp4":
model.predict_on_video_file(
video_file=path, generate_labeled_video=(not skip_viz)
)
elif path.suffix == ".csv":
model.predict_on_label_csv(
csv_file=path,
generate_labeled_images=False, # TODO: implement visualization
)
elif path.suffix in [".png", ".jpg"]:
raise NotImplementedError("Not yet implemented: predicting on image files.")
else:
pass


if __name__ == "__main__":
main()
35 changes: 35 additions & 0 deletions cli/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import argparse
from pathlib import Path


def config_file(filepath):
"""
Custom argparse type for validating that a file exists and is a yaml file.
Args:
filepath: The file path string.
Returns:
A pathlib.Path object if the file is valid, otherwise raises an error.
"""
path = Path(filepath)
if not path.is_file():
raise argparse.ArgumentTypeError(f"File not found: {filepath}")
if not path.suffix == ".yaml":
raise argparse.ArgumentTypeError(f"File must be a yaml file: {filepath}")
return path


def model_dir(filepath):
path = Path(filepath)
return path


def existing_model_dir(filepath):
path = model_dir(filepath)
if not path.is_dir():
raise argparse.ArgumentTypeError(
f"Directory model_dir does not exist: {filepath}"
)

return path
Loading

0 comments on commit 047793f

Please sign in to comment.