Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
peiva-git committed Feb 19, 2024
1 parent 863b4c5 commit 1c650a8
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion basketballtrainer/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
take a look at the `basketballtrainer.data.convert_dataset` module.
"""

from .convert_dataset import convert_dataset_to_paddleseg_format
from .convert_dataset import convert_dataset_to_paddleseg_format, pseudocolor_mask_to_grayscale
from .dataset_builders import PaddleSegDatasetBuilder
9 changes: 9 additions & 0 deletions basketballtrainer/data/convert_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import pathlib
import shutil

import numpy as np


def convert_dataset_to_paddleseg_format(dataset_path: str, target_path: str):
"""
Expand Down Expand Up @@ -75,3 +77,10 @@ def __generate_ordered_filenames_lists(source: pathlib.Path) -> ([str], [str]):
)
args = parser.parse_args()
convert_dataset_to_paddleseg_format(args.source_dir, args.target_dir)


def pseudocolor_mask_to_grayscale(mask: np.ndarray) -> np.ndarray:
mask = np.argmax(mask, axis=-1, keepdims=False)
mask = mask - 2
mask = np.absolute(mask)
return mask
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
import numpy as np
import paddle as pp


def pseudocolor_mask_to_grayscale(mask: np.ndarray) -> np.ndarray:
mask = np.argmax(mask, axis=-1, keepdims=False)
mask = mask - 2
mask = np.absolute(mask)
return mask
from basketballtrainer.data import pseudocolor_mask_to_grayscale


def postprocess_mask(mask: np.ndarray, min_size: int, max_radius: int, filters=('rm-small',)) -> np.ndarray:
Expand All @@ -28,10 +23,10 @@ def postprocess_mask(mask: np.ndarray, min_size: int, max_radius: int, filters=(
return filtered.astype(np.int64)


def postprocess_labels_dir(source_dir: pathlib.Path,
target_dir: pathlib.Path,
min_size: int = 625,
max_radius: int = 19):
def postprocess_masks_dir(source_dir: pathlib.Path,
target_dir: pathlib.Path,
min_size: int = 625,
max_radius: int = 19):
source_pattern = str(source_dir / '*.png')
masks = imread_collection(source_pattern)
for mask_index, mask in enumerate(masks):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Pillow>=10.0.0
numpy>=1.25.0
opencv-python>=4.5.0
pdoc==14.1.*
scikit-image~=0.22

0 comments on commit 1c650a8

Please sign in to comment.