-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding some more API classes (participant, session, dataset) #50
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""This script is used to test the senselab classes.""" | ||
|
||
from senselab.utils.data_structures.dataset import Participant, SenselabDataset, Session | ||
|
||
# ---------------------------------------------------------------------------
# Part 1: build participants and sessions on their own and show them.
# ---------------------------------------------------------------------------

# Only metadata is given here, so the id comes from each model's default factory.
participant1 = Participant(metadata={"name": "John Doe", "age": 30})
session1 = Session(metadata={"session_name": "Baseline", "date": "2024-06-01"})

print(participant1, session1, sep="\n")

# A participant whose id is chosen by the caller.
participant2 = Participant(
    id="custom-id-123",
    metadata={"name": "Jane Smith", "age": 25},
)
print(participant2)

# A session built with no arguments at all (generated id, default metadata).
session2 = Session()
print(session2)

# ---------------------------------------------------------------------------
# Part 2: register participants and sessions in a dataset.
# ---------------------------------------------------------------------------
dataset = SenselabDataset()

try:
    participant1 = Participant(metadata={"name": "John Doe", "age": 30})
    dataset.add_participant(participant1)

    participant2 = Participant(metadata={"name": "Jane Smith", "age": 25})
    dataset.add_participant(participant2)

    # First use of the explicit id "123".
    participant3 = Participant(id="123", metadata={"name": "Alice"})
    dataset.add_participant(participant3)

    # Second use of id "123": add_participant is expected to raise ValueError here.
    participant4 = Participant(id="123", metadata={"name": "Bob"})
    dataset.add_participant(participant4)
except ValueError as err:
    print("Value error:", err)

try:
    session1 = Session(metadata={"session_name": "Baseline", "date": "2024-06-01"})
    dataset.add_session(session1)

    session2 = Session(metadata={"session_name": "Follow-up", "date": "2024-07-01"})
    dataset.add_session(session2)

    # First use of the explicit id "123".
    session3 = Session(id="123")
    dataset.add_session(session3)

    # Second use of id "123": add_session is expected to raise ValueError here.
    session4 = Session(id="123")
    dataset.add_session(session4)
except ValueError as err:
    print("Value error:", err)

# Show everything that actually made it into the dataset.
print(dataset.get_participants())
print(dataset.get_sessions())
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""This script is used to test the HF API.""" | ||
|
||
from senselab.utils.hf import _check_hf_repo_exists | ||
|
||
# Probe the helper with fixed arguments and show whatever it returns.
# NOTE(review): the three positional arguments look like (repo id, commit hash, repo type);
# confirm against the signature of _check_hf_repo_exists.
res = _check_hf_repo_exists(
    "gpt2",
    "607a30d783dfa663caf39e06633721c8d4cfcd7e",
    "model",
)
print(res)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""This module provides the implementation of torchaudio utilities for audio features extraction.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Data structures relevant for managing datasets.""" | ||
|
||
import uuid | ||
from typing import Any, Dict, List | ||
|
||
from pydantic import BaseModel, Field, field_validator | ||
|
||
|
||
class Participant(BaseModel):
    """Data structure for a participant in a dataset.

    Attributes:
        id: Identifier of the participant. When the argument is omitted, a random
            UUID4 string is generated by the field's default factory.
        metadata: Free-form dictionary of information about the participant;
            defaults to an empty dict.

    Two participants compare equal when their ids are equal; metadata is ignored.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # default_factory keeps this consistent with the dict fields of SenselabDataset.
    metadata: Dict = Field(default_factory=dict)

    @field_validator("id", mode="before")
    @classmethod
    def set_id(cls, v: Any) -> str:  # noqa: ANN401
        """Return the supplied id, or a fresh UUID4 string when it is falsy.

        This runs before type validation, so an explicitly passed ``None`` or
        empty string is replaced by a generated id instead of being rejected or kept.
        """
        return v or str(uuid.uuid4())

    def __eq__(self, other: object) -> bool:
        """Compare participants by id only.

        Returns ``NotImplemented`` for non-Participant operands so Python can try the
        reflected comparison; the final result for unrelated types is still ``False``.
        """
        if not isinstance(other, Participant):
            return NotImplemented
        return self.id == other.id
|
||
|
||
class Session(BaseModel):
    """Data structure for a session in a dataset.

    Attributes:
        id: Identifier of the session. When the argument is omitted, a random
            UUID4 string is generated by the field's default factory.
        metadata: Free-form dictionary of information about the session;
            defaults to an empty dict.

    Two sessions compare equal when their ids are equal; metadata is ignored.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # default_factory keeps this consistent with the dict fields of SenselabDataset.
    metadata: Dict = Field(default_factory=dict)

    @field_validator("id", mode="before")
    @classmethod
    def set_id(cls, v: Any) -> str:  # noqa: ANN401
        """Return the supplied id, or a fresh UUID4 string when it is falsy.

        This runs before type validation, so an explicitly passed ``None`` or
        empty string is replaced by a generated id instead of being rejected or kept.
        """
        return v or str(uuid.uuid4())

    def __eq__(self, other: object) -> bool:
        """Compare sessions by id only.

        Returns ``NotImplemented`` for non-Session operands so Python can try the
        reflected comparison; the final result for unrelated types is still ``False``.
        """
        if not isinstance(other, Session):
            return NotImplemented
        return self.id == other.id
|
||
|
||
class SenselabDataset(BaseModel): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems like we made complementary datasets (mine handles Audio/Video, yours has some base functions for the participants and sessions). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds amazing. If you want, we can have a coding session together tomorrow after lunch and merge everything we have. |
||
"""Data structure for a Senselab dataset.""" | ||
|
||
participants: Dict[str, Participant] = Field(default_factory=dict) | ||
sessions: Dict[str, Session] = Field(default_factory=dict) | ||
|
||
@field_validator("participants", mode="before")
@classmethod
def check_unique_participant_id(
    cls,
    v: Dict[str, Participant],
    values: Any,  # noqa: ANN401
) -> Dict[str, Participant]:
    """Check if participant IDs are unique.

    Args:
        v: Raw ``participants`` input, expected to be a mapping keyed by participant id.
        values: Validation context. Under pydantic v2 this is a ``ValidationInfo`` whose
            ``data`` attribute holds the fields validated so far; a plain dict is
            accepted as well.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If an id in ``v`` is already present in the validated participants.
    """
    # No debug printing or input() here: a validator must never block on stdin.
    if isinstance(values, dict):
        validated = values
    else:
        # ValidationInfo has no .get(); the already-validated fields live in .data.
        validated = getattr(values, "data", None) or {}
    known = validated.get("participants", {})

    # NOTE(review): while "participants" itself is being validated it is normally not
    # yet part of the validated data, so this check is not expected to fire; runtime
    # duplicate protection comes from add_participant. Confirm the intended behaviour.
    if isinstance(v, dict):  # non-dict input is left for pydantic's own type validation
        for participant_id in v:
            if participant_id in known:
                raise ValueError(f"Participant with ID {participant_id} already exists.")
    return v
|
||
@field_validator("sessions", mode="before")
@classmethod
def check_unique_session_id(
    cls,
    v: Dict[str, Session],
    values: Any,  # noqa: ANN401
) -> Dict[str, Session]:
    """Check if session IDs are unique.

    Args:
        v: Raw ``sessions`` input, expected to be a mapping keyed by session id.
        values: Validation context. Under pydantic v2 this is a ``ValidationInfo`` whose
            ``data`` attribute holds the fields validated so far; a plain dict is
            accepted as well.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If an id in ``v`` is already present in the validated sessions.
    """
    if isinstance(values, dict):
        validated = values
    else:
        # ValidationInfo has no .get(); the already-validated fields live in .data.
        validated = getattr(values, "data", None) or {}
    known = validated.get("sessions", {})

    # NOTE(review): while "sessions" itself is being validated it is normally not yet
    # part of the validated data, so this check is not expected to fire; runtime
    # duplicate protection comes from add_session. Confirm the intended behaviour.
    if isinstance(v, dict):  # non-dict input is left for pydantic's own type validation
        for session_id in v:
            if session_id in known:
                raise ValueError(f"Session with ID {session_id} already exists.")
    return v
|
||
def add_participant(self, participant: Participant) -> None:
    """Register ``participant`` in the dataset under its id.

    Raises:
        ValueError: If a participant with the same id is already registered.
    """
    registry = self.participants
    if participant.id not in registry:
        registry[participant.id] = participant
        return
    raise ValueError(f"Participant with ID {participant.id} already exists.")
|
||
def add_session(self, session: Session) -> None:
    """Register ``session`` in the dataset under its id.

    Raises:
        ValueError: If a session with the same id is already registered.
    """
    registry = self.sessions
    if session.id not in registry:
        registry[session.id] = session
        return
    raise ValueError(f"Session with ID {session.id} already exists.")
|
||
def get_participants(self) -> List[Participant]:
    """Return a new list holding every participant stored in the dataset."""
    return [*self.participants.values()]
|
||
def get_sessions(self) -> List[Session]:
    """Return a new list holding every session stored in the dataset."""
    return [*self.sessions.values()]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if this will ever get run, because you have a default_factory for when the id is not provided, so it is probably not needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But the default_factory is if you use it in a Dataset. Otherwise, you need some checks (and this is why we have the validator)