Commit: Merge branch 'main' into 232-bug-_select_device_and_dtype-does-not-work-on-virtualized-hardware
Showing 17 changed files with 1,377 additions and 0 deletions.
@@ -1,3 +1,11 @@
"""This module provides the API of the senselab video data structures.""" | ||
|
||
from .pose import ( # noqa: F401 | ||
ImagePose, | ||
IndividualPose, | ||
MediaPipePoseLandmark, | ||
PoseLandmark, | ||
PoseModel, | ||
YOLOPoseLandmark, | ||
) | ||
from .video import Video # noqa: F401 |
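For orientation, a minimal sketch of how downstream code is meant to consume these re-exports (the import path follows the module paths used elsewhere in this commit):

```python
# Sketch: importing the re-exported names from the data_structures subpackage.
from senselab.video.data_structures import ImagePose, IndividualPose, Video
```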
@@ -0,0 +1,162 @@
"""Data structures relevant for pose estimation.""" | ||
|
||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
from typing import Dict, List, Optional | ||
|
||
import numpy as np | ||
from pydantic import BaseModel, ConfigDict, field_validator | ||
|
||
|
||
class PoseModel(str, Enum): | ||
"""Enum representing different pose estimation models. | ||
Attributes: | ||
MEDIAPIPE (str): Enum value for MediaPipe pose estimation. | ||
YOLO (str): Enum value for YOLO pose estimation. | ||
""" | ||
|
||
MEDIAPIPE = "mp" | ||
YOLO = "yolo" | ||
|
||
|
||
class PoseLandmark(ABC, BaseModel): | ||
"""Abstract base class for representing pose landmarks. | ||
Methods: | ||
as_list() -> List[float]: | ||
Convert the landmark data to a list of coordinates. | ||
""" | ||
|
||
@abstractmethod | ||
def as_list(self) -> List[float]: | ||
"""Convert the landmark data to a list of floating-point values. | ||
Returns: | ||
List[float]: List of landmark data as floats. | ||
""" | ||
pass | ||
|
||
|
||
class MediaPipePoseLandmark(PoseLandmark): | ||
"""Represents a pose landmark detected by MediaPipe. | ||
Attributes: | ||
x (float): X-coordinate of the landmark. | ||
y (float): Y-coordinate of the landmark. | ||
z (float): Z-coordinate of the landmark. | ||
visibility (float): Probability of the landmark being visible [0, 1]. | ||
""" | ||
|
||
x: float | ||
y: float | ||
z: float | ||
visibility: float | ||
|
||
def as_list(self) -> List[float]: | ||
"""Convert the landmark data to a list. | ||
Returns: | ||
List[float]: [x, y, z, visibility] values. | ||
""" | ||
return [self.x, self.y, self.z, self.visibility] | ||
|
||
|
||
class YOLOPoseLandmark(PoseLandmark): | ||
"""Represents a pose keypoint detected by YOLO. | ||
Attributes: | ||
x (float): X-coordinate of the keypoint. | ||
y (float): Y-coordinate of the keypoint. | ||
confidence (float): Confidence score of the detected keypoint [0, 1]. | ||
""" | ||
|
||
x: float | ||
y: float | ||
confidence: float | ||
|
||
def as_list(self) -> List[float]: | ||
"""Convert the keypoint data to a list. | ||
Returns: | ||
List[float]: [x, y, confidence] values. | ||
""" | ||
return [self.x, self.y, self.confidence] | ||
|
||
|
||
class IndividualPose(BaseModel): | ||
"""Data structure for the estimated pose of a single individual in an image. | ||
Attributes: | ||
individual_index (int): Index of the individual in the detection result. | ||
normalized_landmarks (Dict[str, PoseLandmark]): Dictionary of body landmarks with normalized coordinates. | ||
world_landmarks (Optional[Dict[str, PoseLandmark]]): Dictionary of body landmarks with real-world coordinates. | ||
""" | ||
|
||
individual_index: int | ||
normalized_landmarks: Dict[str, PoseLandmark] | ||
world_landmarks: Optional[Dict[str, PoseLandmark]] = None | ||
|
||
def get_landmark_coordinates(self, landmark: str, world: bool = False) -> PoseLandmark: | ||
"""Retrieve coordinates for a specific landmark. | ||
Args: | ||
landmark (str): Name of the landmark (e.g., "right_ankle"). | ||
world (bool): Whether to retrieve world coordinates. Defaults to False. | ||
Returns: | ||
PoseLandmark: Object containing information on the specified landmark. | ||
Raises: | ||
ValueError: If the specified landmark is not found. | ||
""" | ||
landmarks = self.world_landmarks if world else self.normalized_landmarks | ||
if not landmarks: | ||
raise ValueError("No landmarks available.") | ||
if landmark not in landmarks: | ||
raise ValueError(f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks.keys())}") | ||
return landmarks[landmark] | ||
|
||
|
||
class ImagePose(BaseModel): | ||
"""Data structure for storing estimated poses of multiple individuals in an image. | ||
Attributes: | ||
model (PoseModel): The model used for pose estimation. | ||
image (np.ndarray): Original image as a NumPy array. | ||
individuals (List[IndividualPose]): List of IndividualPose objects for each detected individual. | ||
""" | ||
|
||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
model: PoseModel | ||
image: np.ndarray | ||
individuals: List[IndividualPose] | ||
|
||
@field_validator("image", mode="before") | ||
def validate_image(cls, value: np.ndarray) -> np.ndarray: | ||
"""Ensures image is a 3D numpy array with three color channels.""" | ||
if not isinstance(value, np.ndarray): | ||
raise TypeError("Field 'image' must be a NumPy array.") | ||
if value.ndim != 3 or value.shape[2] != 3: | ||
raise ValueError("Field 'image' must be a 3D array with three color channels (RGB).") | ||
return value | ||
|
||
def get_individual(self, individual_index: int) -> IndividualPose: | ||
"""Retrieve a specific individual's pose data. | ||
Args: | ||
individual_index (int): Index of the individual to retrieve. | ||
Returns: | ||
IndividualPose: Pose data for the specified individual. | ||
Raises: | ||
ValueError: If the index is invalid or no individuals are detected. | ||
""" | ||
if individual_index >= len(self.individuals) or individual_index < 0: | ||
raise ValueError( | ||
f"Individual index {individual_index} is invalid. {len(self.individuals)} poses were estimated. " | ||
f"Valid indices are {f'0 to {len(self.individuals)-1}' if len(self.individuals) > 0 else 'none'}" | ||
) | ||
return self.individuals[individual_index] |
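Taken together, the classes nest as landmark → individual → image. Below is a minimal construction sketch; the coordinate values and the single landmark name are illustrative only, since real instances are produced by the estimators added elsewhere in this commit:

```python
# Illustrative sketch only: values and the landmark name are made up.
import numpy as np

from senselab.video.data_structures.pose import (
    ImagePose,
    IndividualPose,
    MediaPipePoseLandmark,
    PoseModel,
)

ankle = MediaPipePoseLandmark(x=0.52, y=0.81, z=-0.10, visibility=0.98)
person = IndividualPose(individual_index=0, normalized_landmarks={"right_ankle": ankle})
result = ImagePose(
    model=PoseModel.MEDIAPIPE,
    image=np.zeros((480, 640, 3), dtype=np.uint8),  # dummy RGB frame passes the validator
    individuals=[person],
)

print(result.get_individual(0).get_landmark_coordinates("right_ankle").as_list())
# -> [0.52, 0.81, -0.1, 0.98]
```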
@@ -0,0 +1 @@
"""Tasks for video processing.""" |
@@ -0,0 +1,3 @@
""".. include:: ./doc.md""" # noqa: D415 | ||
|
||
from .api import estimate_pose, visualize_pose # noqa: F401 |
@@ -0,0 +1,76 @@
"""This module provides the API for pose estimation tasks.""" | ||
|
||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
|
||
from senselab.video.data_structures.pose import ImagePose | ||
from senselab.video.tasks.pose_estimation.estimate import ( | ||
MediaPipePoseEstimator, | ||
YOLOPoseEstimator, | ||
) | ||
from senselab.video.tasks.pose_estimation.visualization import visualize | ||
|
||
|
||
def estimate_pose(image_path: str, model: str, **kwargs: Any) -> ImagePose: # noqa ANN401 | ||
"""Estimate poses in an image using the specified model. | ||
Args: | ||
image_path (str): Path to the input image file. | ||
model (str): The model to use for pose estimation. Options are 'mediapipe' or 'yolo'. | ||
**kwargs: Additional keyword arguments for model-specific configurations: | ||
- For MediaPipe: | ||
- model_type (str): Type of MediaPipe model ('lite', 'full', 'heavy'). Defaults to 'lite'. | ||
- num_individuals (int): Maximum number of individuals to detect. Defaults to 1. | ||
- For YOLO: | ||
- model_type (str): Type of YOLO model ('8n', '8s', '11l', etc.). Defaults to '8n'. | ||
Returns: | ||
ImagePose: An object containing the estimated poses. | ||
Raises: | ||
ValueError: If an unsupported model or invalid arguments are provided. | ||
Examples: | ||
>>> estimate_pose("image.jpg", model="mediapipe", model_type="full", num_individuals=2) | ||
>>> estimate_pose("image.jpg", model="yolo", model_type="8n") | ||
""" | ||
if model == "mediapipe": | ||
model_type = kwargs.get("model_type", "lite") | ||
num_individuals = kwargs.get("num_individuals", 1) | ||
|
||
if not isinstance(model_type, str): | ||
raise ValueError("Invalid 'model_type' for MediaPipe. Must be a string.") | ||
if not isinstance(num_individuals, int) or num_individuals < 1: | ||
raise ValueError("'num_individuals' must be a positive integer.") | ||
|
||
estimator = MediaPipePoseEstimator(model_type=model_type) | ||
return estimator.estimate_from_path(image_path, num_individuals=num_individuals) | ||
|
||
elif model == "yolo": | ||
model_type = kwargs.get("model_type", "8n") # type: ignore[no-redef] | ||
|
||
if not isinstance(model_type, str): | ||
raise ValueError("Invalid 'model_type' for YOLO. Must be a string.") | ||
|
||
estimator = YOLOPoseEstimator(model_type=model_type) # type: ignore[assignment] | ||
return estimator.estimate_from_path(image_path) | ||
|
||
else: | ||
raise ValueError(f"Unsupported model: {model}") | ||
|
||
|
||
def visualize_pose(pose_image: ImagePose, output_path: Optional[str] = None, plot: bool = False) -> np.ndarray: | ||
"""Visualize detected poses by drawing landmarks and connections on the image. | ||
Args: | ||
pose_image (ImagePose): The pose estimation result containing detected poses. | ||
output_path (str): Optional path to save the visualized image. If provided, saves the | ||
annotated image to this location. | ||
plot (bool): Whether to display the annotated image using matplotlib. | ||
Returns: | ||
np.ndarray: The input image with pose landmarks and connections drawn on it. | ||
""" | ||
annotated_image = visualize(pose_image, output_path=output_path, plot=plot) | ||
return annotated_image |
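Read end to end, the two functions compose into a short pipeline. A minimal usage sketch, with `photo.jpg` and `photo_annotated.jpg` as placeholder paths:

```python
# Usage sketch for the API above; file paths are placeholders.
from senselab.video.tasks.pose_estimation import estimate_pose, visualize_pose

result = estimate_pose("photo.jpg", model="mediapipe", model_type="full", num_individuals=2)
for person in result.individuals:
    print(person.individual_index, len(person.normalized_landmarks))

# Draw landmarks, optionally save to disk, and get the annotated array back.
annotated = visualize_pose(result, output_path="photo_annotated.jpg", plot=False)
print(annotated.shape)
```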
@@ -0,0 +1,43 @@
# Pose Estimation

[![Tutorial](https://img.shields.io/badge/Tutorial-Click%20Here-blue?style=for-the-badge)](https://github.com/sensein/senselab/blob/main/tutorials/video/pose_estimation.ipynb)

## Task Overview

Pose estimation is the process of detecting and tracking key points on a human body or other objects in images or videos. These key points represent joints, limbs, or other significant regions of interest. Pose estimation is widely used in applications such as motion analysis, sports performance tracking, gesture recognition, and augmented reality.

`senselab` supports pose estimation using **MediaPipe** and **YOLO**, offering models with varying accuracy, speed, and computational requirements.
---

## Models

### [MediaPipe](https://ai.google.dev/edge/mediapipe/solutions/vision/pose_landmarker)

MediaPipe provides three pose estimation models:
- **Lite**: Optimized for mobile devices with low latency requirements.
- **Full**: Balanced between accuracy and efficiency, suitable for most applications.
- **Heavy**: High-accuracy model designed for tasks where precision is critical but latency is less of a concern.

These models detect 33 key points across the body, including joints, eyes, ears, and the nose.
### YOLO

YOLO-based pose estimation models are efficient and capable of detecting key points in real time. Supported variants include the **[YOLOv8](https://docs.ultralytics.com/models/yolov8/)** and **[YOLOv11](https://docs.ultralytics.com/models/yolo11/#what-tasks-can-yolo11-models-perform)** families, with increasing model sizes (e.g., `8n`, `8s`, `11n`, `11l`) to balance speed and accuracy.

These models detect 17 key points, including joints such as shoulders, elbows, knees, and ankles.
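To make the selection concrete, both backends are reached through the `estimate_pose` API added in this commit (a sketch; `image.jpg` is a placeholder path):

```python
from senselab.video.tasks.pose_estimation import estimate_pose

# MediaPipe: pick 'lite', 'full', or 'heavy'; detects up to num_individuals people.
mp_result = estimate_pose("image.jpg", model="mediapipe", model_type="heavy", num_individuals=3)

# YOLO: pick a size variant such as '8n', '8s', or '11l'.
yolo_result = estimate_pose("image.jpg", model="yolo", model_type="11l")
```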
## Evaluation

### Metrics

- **Percentage of Correct Parts (PCP)**: Evaluates limb detection accuracy. A limb is considered correct if the predicted key points are within half the limb’s length of the true points.
- **Percentage of Correct Keypoints (PCK)**: Considers a key point correct if the distance between the true and predicted points is within a threshold (e.g., 0.2 times the person’s head bone length).
- **Percentage of Detected Joints (PDJ)**: Counts a joint as correct if the true and predicted points are within a fraction of the torso’s diameter.
- **Object Keypoint Similarity (OKS)**: Measures the agreement between true and predicted key points using a distance normalized by the person’s scale; OKS values are typically thresholded to compute the mean Average Precision (mAP) over all key points in the frame. A NumPy sketch of the OKS computation follows this list.
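As a concrete illustration of OKS, here is a minimal NumPy sketch following the COCO-style definition. The per-keypoint falloff constants `k` and the object scale `s` are conventions defined by the evaluation protocol (assumptions here), not something this commit ships:

```python
import numpy as np

def object_keypoint_similarity(
    pred: np.ndarray,     # (K, 2) predicted keypoint coordinates
    truth: np.ndarray,    # (K, 2) ground-truth keypoint coordinates
    visible: np.ndarray,  # (K,) boolean mask of labeled/visible keypoints
    k: np.ndarray,        # (K,) per-keypoint falloff constants (dataset-defined)
    s: float,             # object scale, e.g. sqrt of the person's segment area
) -> float:
    """COCO-style OKS: mean of exp(-d_i^2 / (2 * s^2 * k_i^2)) over visible keypoints."""
    d2 = np.sum((pred - truth) ** 2, axis=-1)         # squared distances d_i^2
    similarity = np.exp(-d2 / (2 * (s**2) * (k**2)))  # per-keypoint similarity in (0, 1]
    return float(similarity[visible].mean())          # average over labeled keypoints
```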
### Benchmark Datasets

- **[COCO Keypoints](https://cocodataset.org/#keypoints-2020)**: Annotated key points for human poses in diverse scenes.
- **[MPII Human Pose](http://human-pose.mpi-inf.mpg.de/)**: Dataset focused on human pose estimation.
- **[Leeds Sports Pose Extended](https://github.com/axelcarlier/lsp)**: 10,000 sports images with up to 14 human joint annotations.