Skip to content

Commit

Permalink
Update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Scienfitz committed Jan 6, 2025
1 parent 8213bef commit 0a42492
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import functools
import logging
from collections.abc import Callable, Collection, Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Literal, TypeVar, overload

import numpy as np
Expand Down Expand Up @@ -70,7 +70,7 @@ def to_tensor(*x: np.ndarray | pd.DataFrame) -> Tensor | tuple[Tensor, ...]:

def add_fake_measurements(
data: pd.DataFrame,
targets: Collection[Target],
targets: Iterable[Target],
good_reference_values: dict[str, list] | None = None,
good_intervals: dict[str, tuple[float, float]] | None = None,
bad_intervals: dict[str, tuple[float, float]] | None = None,
Expand Down Expand Up @@ -279,8 +279,8 @@ def add_parameter_noise(


def create_fake_input(
parameters: Sequence[Parameter],
targets: Sequence[Target],
parameters: Iterable[Parameter],
targets: Iterable[Target],
n_rows: int = 1,
**kwargs: dict,
) -> pd.DataFrame:
Expand Down
6 changes: 3 additions & 3 deletions baybe/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import math
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any

import pandas as pd
Expand Down Expand Up @@ -75,7 +75,7 @@ def validator(self: Any, attribute: Attribute, value: Any) -> None:
"""Validator for non-infinite floats."""


def validate_target_input(data: pd.DataFrame, targets: Sequence[Target]) -> None:
def validate_target_input(data: pd.DataFrame, targets: Iterable[Target]) -> None:
"""Validate input dataframe columns corresponding to targets.
Args:
Expand Down Expand Up @@ -117,7 +117,7 @@ def validate_target_input(data: pd.DataFrame, targets: Sequence[Target]) -> None

def validate_parameter_input(
data: pd.DataFrame,
parameters: Sequence[Parameter],
parameters: Iterable[Parameter],
numerical_measurements_must_be_within_tolerance: bool = False,
) -> None:
"""Validate input dataframe columns corresponding to parameters.
Expand Down

0 comments on commit 0a42492

Please sign in to comment.