
adding some more API classes (participant, session, dataset) #50

Merged: 3 commits, Jun 5, 2024
216 changes: 108 additions & 108 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -28,7 +28,7 @@ packages = [{include = "senselab", from = "src"}]
python = "^3.10"
click = "^8.1.7"
jsonschema = "^4.21.1"
-datasets = "^2.18.0"
+datasets = "^2.19.2"
torch = "^2.2.2"
torchvision = "^0.17.2"
torchaudio = "^2.2.2"
@@ -38,9 +38,9 @@ soundfile = "^0.12.1"
ffmpeg-python = "^0.2.0"
ipykernel = "^6.29.4"
pydra = "^0.23"
-pydantic = "^2.7.1"
+pydantic = "^2.7.3"
accelerate = "^0.29.3"
-huggingface-hub = "^0.23.0"
+huggingface-hub = "^0.23.3"
praat-parselmouth = "^0.4.3"
iso-639 = {git = "https://github.com/noumar/iso639.git", tag = "0.4.5"}
opensmile = "^2.5.0"
@@ -54,7 +54,7 @@ speechbrain = "^1.0.0"
optional = true

[tool.poetry.group.dev.dependencies]
-pytest = "^8.1.1"
+pytest = "^8.2.2"
pytest-mock = "^3.14.0"
mypy = "^1.9.0"
pre-commit = "^3.7.0"
59 changes: 59 additions & 0 deletions scripts/experiment7.py
@@ -0,0 +1,59 @@
"""This script is used to test the senselab classes."""

from senselab.utils.data_structures.dataset import Participant, SenselabDataset, Session

# Example usage
participant1 = Participant(metadata={"name": "John Doe", "age": 30})
session1 = Session(metadata={"session_name": "Baseline", "date": "2024-06-01"})

print(participant1)
print(session1)

# Creating another participant with a specific ID
participant2 = Participant(id="custom-id-123", metadata={"name": "Jane Smith", "age": 25})
print(participant2)

# Creating a session with default ID
session2 = Session()
print(session2)


# Example usage
dataset = SenselabDataset()

try:
participant1 = Participant(metadata={"name": "John Doe", "age": 30})
dataset.add_participant(participant1)

participant2 = Participant(metadata={"name": "Jane Smith", "age": 25})
dataset.add_participant(participant2)

# Creating another participant with a specific ID
participant3 = Participant(id="123", metadata={"name": "Alice"})
dataset.add_participant(participant3)

# Attempting to add another participant with the same ID should raise an error
participant4 = Participant(id="123", metadata={"name": "Bob"})
dataset.add_participant(participant4)
except ValueError as e:
print("Value error:", e)

try:
session1 = Session(metadata={"session_name": "Baseline", "date": "2024-06-01"})
dataset.add_session(session1)

session2 = Session(metadata={"session_name": "Follow-up", "date": "2024-07-01"})
dataset.add_session(session2)

# Creating a session with a specific ID
session3 = Session(id="123")
dataset.add_session(session3)

# Attempting to add another session with the same ID should raise an error
session4 = Session(id="123")
dataset.add_session(session4)
except ValueError as e:
print("Value error:", e)

# Print all participants and sessions
print(dataset.get_participants())
print(dataset.get_sessions())
6 changes: 6 additions & 0 deletions scripts/experiment8.py
@@ -0,0 +1,6 @@
"""This script is used to test the HF API."""

from senselab.utils.hf import _check_hf_repo_exists

res = _check_hf_repo_exists("gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e", "model")
print(res)
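
With the updated signature of _check_hf_repo_exists in src/senselab/utils/hf.py, the positional arguments above map to revision and repo_type. An equivalent keyword-argument form (a sketch of the same call) would be:

res = _check_hf_repo_exists(repo_id="gpt2", revision="607a30d783dfa663caf39e06633721c8d4cfcd7e", repo_type="model")
print(res)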
1 change: 1 addition & 0 deletions src/senselab/audio/tasks/features_extraction/torchaudio.py
@@ -0,0 +1 @@
"""This module provides the implementation of torchaudio utilities for audio features extraction."""
90 changes: 90 additions & 0 deletions src/senselab/utils/data_structures/dataset.py
@@ -0,0 +1,90 @@
"""Data structures relevant for managing datasets."""

import uuid
from typing import Dict, List

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Participant(BaseModel):
"""Data structure for a participant in a dataset."""

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
metadata: Dict = Field(default={})

@field_validator("id", mode="before")
def set_id(cls, v: str) -> str:
Collaborator: I don't know if this will ever get run, because you have a default_factory when the id is not provided, so it's probably not needed.

Collaborator (author): But the default_factory applies when you use it in a Dataset; otherwise, you need some checks (and this is why we have the validator).
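
A minimal sketch of the point being made here (using only the Participant class from this diff): default_factory fires only when the field is omitted entirely, so the before-mode validator is what rescues an explicitly falsy id.

from senselab.utils.data_structures.dataset import Participant

p1 = Participant()         # id supplied by default_factory
p2 = Participant(id="")    # default_factory is skipped; set_id swaps "" for a fresh UUID
assert p1.id and p2.id and p1.id != p2.id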

"""Set the unique id of the participant."""
return v or str(uuid.uuid4())

def __eq__(self, other: object) -> bool:
"""Overloads the default BaseModel equality to correctly check that ids are equivalent."""
if isinstance(other, Participant):
return self.id == other.id
return False


class Session(BaseModel):
"""Data structure for a session in a dataset."""

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
metadata: Dict = Field(default={})

@field_validator("id", mode="before")
def set_id(cls, v: str) -> str:
"""Set the unique id of the session."""
return v or str(uuid.uuid4())

def __eq__(self, other: object) -> bool:
"""Overloads the default BaseModel equality to correctly check that ids are equivalent."""
if isinstance(other, Session):
return self.id == other.id
return False


class SenselabDataset(BaseModel):
Collaborator: Seems like we made complementary datasets (mine handles Audio/Video, yours has some base functions for the participants and sessions).

Collaborator (author): Sounds amazing. If you want, we can have a coding session together tomorrow after lunch and merge everything we have.

"""Data structure for a Senselab dataset."""

participants: Dict[str, Participant] = Field(default_factory=dict)
sessions: Dict[str, Session] = Field(default_factory=dict)

@field_validator("participants", mode="before")
def check_unique_participant_id(cls, v: Dict[str, Participant], values: Any) -> Dict[str, Participant]: # noqa: ANN401
"""Check if participant IDs are unique."""
print("type(values)")
print(type(values))
input("Press Enter to continue...")
participants = values.get("participants", {})
for participant_id, _ in v.items():
if participant_id in participants:
raise ValueError(f"Participant with ID {participant_id} already exists.")
return v

@field_validator("sessions", mode="before")
def check_unique_session_id(cls, v: Dict[str, Session], values: Any) -> Dict[str, Session]: # noqa: ANN401
"""Check if session IDs are unique."""
sessions = values.get("sessions", {})
for session_id, _ in v.items():
if session_id in sessions:
raise ValueError(f"Session with ID {session_id} already exists.")
return v

def add_participant(self, participant: Participant) -> None:
"""Add a participant to the dataset."""
if participant.id in self.participants:
raise ValueError(f"Participant with ID {participant.id} already exists.")
self.participants[participant.id] = participant

def add_session(self, session: Session) -> None:
"""Add a session to the dataset."""
if session.id in self.sessions:
raise ValueError(f"Session with ID {session.id} already exists.")
self.sessions[session.id] = session

def get_participants(self) -> List[Participant]:
"""Get the list of participants in the dataset."""
return list(self.participants.values())

def get_sessions(self) -> List[Session]:
"""Get the list of sessions in the dataset."""
return list(self.sessions.values())
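
Two behaviors of the classes above worth calling out: equality for Participant and Session is id-based, and uniqueness is enforced by the add_* methods. A minimal sketch (using only the classes in this file):

from senselab.utils.data_structures.dataset import Participant, SenselabDataset

a = Participant(id="p1", metadata={"name": "John"})
b = Participant(id="p1", metadata={"name": "Johnny"})
assert a == b  # __eq__ compares ids only, not metadata

ds = SenselabDataset()
ds.add_participant(a)
# ds.add_participant(b) would raise ValueError: "Participant with ID p1 already exists."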
22 changes: 11 additions & 11 deletions src/senselab/utils/hf.py
@@ -16,27 +16,27 @@ class HFModel(BaseModel):

@field_validator("hf_model_id")
def validate_hf_model_id(cls, value: str) -> str:
"""Validate the hf_model_id."""
"""Validate the hf_model_id.

# TODO: enabling using HF token
"""
if not value:
raise ValueError("hf_model_id cannot be empty")
if not os.path.isfile(value) and not _check_hf_repo_exists(
-value, "model", None
+repo_id=value, revision="main", repo_type="model", token=None
):
raise ValueError("hf_model_id is not a valid Hugging Face model")
return value


def _check_hf_repo_exists(
-repo_id: str, repo_type: str, token: Optional[str] = None
+repo_id: str, revision: str = "main", repo_type: str = "model", token: Optional[str] = None
) -> bool:
"""Private function to check if a Hugging Face repository exists."""
api = HfApi()
try:
-repo_refs = api.list_repo_refs(
-repo_id=repo_id, repo_type=repo_type, token=token
-)
-if repo_refs.branches:
-return True
-except Exception as e:
-raise RuntimeError(f"An error occurred: {e}")
-return False
+api.list_repo_commits(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token)
+return True
+except Exception:
+# raise RuntimeError(f"An error occurred: {e}")
+return False
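
One behavioral consequence of the rewrite above: _check_hf_repo_exists now catches lookup failures and returns False rather than raising RuntimeError. A quick sketch (assuming huggingface_hub is installed and the network is reachable):

from senselab.utils.hf import _check_hf_repo_exists

print(_check_hf_repo_exists("gpt2"))  # True: the repo and its "main" revision exist
print(_check_hf_repo_exists("no-such-org/no-such-repo"))  # False: list_repo_commits raises, and the error is swallowed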
37 changes: 10 additions & 27 deletions src/senselab/utils/tasks/input_output.py
@@ -34,9 +34,7 @@ def _from_files_to_dataset(files: List[File]) -> Dataset:
return _from_hf_dataset_to_dict(dataset)


-def read_dataset_from_disk(
-input_path: str, split: str, streaming: bool = False
-) -> Dict[str, Any]:
+def read_dataset_from_disk(input_path: str, split: str, streaming: bool = False) -> Dict[str, Any]:
"""Loads a Hugging Face `Dataset` object from disk.

It determines the format based on the file extension or directory.
@@ -58,9 +56,7 @@ return _from_hf_dataset_to_dict(dataset)
return _from_hf_dataset_to_dict(dataset)
except Exception as e:
# Generic error handling, e.g., network issues, data loading issues
-raise RuntimeError(
-f"An error occurred while loading the dataset: {str(e)}"
-)
+raise RuntimeError(f"An error occurred while loading the dataset: {str(e)}")


def read_dataset_from_hub(
@@ -73,10 +69,9 @@ def read_dataset_from_hub(

It includes support for private repositories.
"""
-if not _check_hf_repo_exists(remote_repository, "dataset", hf_token):
+if not _check_hf_repo_exists(remote_repository, "main", "dataset", hf_token):
raise RuntimeError(
f"The repository {remote_repository} - {revision} - {split}"
" does not exist or could not be accessed."
f"The repository {remote_repository} - {revision} - {split}" " does not exist or could not be accessed."
)

# Load the dataset
@@ -89,9 +84,7 @@ )
)
except Exception as e:
# Generic error handling, e.g., network issues, data loading issues
-raise RuntimeError(
-f"An error occurred while loading the dataset: {str(e)}"
-)
+raise RuntimeError(f"An error occurred while loading the dataset: {str(e)}")

return _from_hf_dataset_to_dict(dataset)

@@ -117,9 +110,7 @@ def push_dataset_to_hub(
token=hf_token,
)
else:
-hf_dataset.push_to_hub(
-repo_id=remote_repository, revision=revision, split=split
-)
+hf_dataset.push_to_hub(repo_id=remote_repository, revision=revision, split=split)
except Exception as e:
raise RuntimeError(f"Failed to push dataset to the hub: {str(e)}")
return
@@ -140,27 +131,21 @@ def save_dataset_to_disk(
output_path = os.path.join(output_directory, output_name)
# No extension for Arrow, it's a directory
else:
-output_path = os.path.join(
-output_directory, f"{output_name}.{output_format}"
-)
+output_path = os.path.join(output_directory, f"{output_name}.{output_format}")

# Create the output directory, ignore error if it already exists
os.makedirs(output_directory, exist_ok=True)

if output_format == "parquet":

-def _save_hf_dataset_as_parquet(
-dataset: Dataset, output_path: str
-) -> None:
+def _save_hf_dataset_as_parquet(dataset: Dataset, output_path: str) -> None:
"""Saves a Hugging Face `Dataset` object to parquet format."""
dataset.to_parquet(output_path)

_save_hf_dataset_as_parquet(hf_dataset, output_path)
elif output_format == "json":

-def _save_hf_dataset_as_json(
-dataset: Dataset, output_path: str
-) -> None:
+def _save_hf_dataset_as_json(dataset: Dataset, output_path: str) -> None:
"""Saves a Hugging Face `Dataset` object to json format."""
dataset.to_json(output_path)

@@ -181,9 +166,7 @@ def _save_hf_dataset_as_sql(dataset: Dataset, output_path: str) -> None:
_save_hf_dataset_as_sql(hf_dataset, output_path)
elif output_format == "arrow":

-def _save_hf_dataset_as_arrow(
-dataset: Dataset, output_path: str
-) -> None:
+def _save_hf_dataset_as_arrow(dataset: Dataset, output_path: str) -> None:
"""Saves a Hugging Face `Dataset` object in Apache Arrow format."""
dataset.save_to_disk(output_path)

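
The format dispatch in save_dataset_to_disk implies usage along these lines (a sketch only: the positional dataset argument and its expected type are assumptions inferred from the function body, not confirmed by this diff):

from datasets import Dataset
from senselab.utils.tasks.input_output import save_dataset_to_disk

toy = Dataset.from_dict({"text": ["a", "b"]})  # placeholder input; the real input type is assumed here
# Hypothetical call: writes ./out/my_dataset.parquet; output_format="arrow" would
# instead save to the directory ./out/my_dataset (no extension).
save_dataset_to_disk(toy, output_directory="./out", output_name="my_dataset", output_format="parquet")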