fix: Fix pyright type checks on benchmark directory #935

Open · wants to merge 12 commits into main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Pylint pre-commit hook is now configured as the Pylint docs recommend. (https://github.com/gchq/coreax/pull/899)
- Type annotations so that core coreax package passes Pyright. (https://github.com/gchq/coreax/pull/906)
- Type annotations so that the example scripts pass Pyright. (https://github.com/gchq/coreax/pull/921)
- Type annotations so that the benchmark scripts pass Pyright. (https://github.com/gchq/coreax/pull/935)

### Changed

27 changes: 19 additions & 8 deletions benchmark/blobs_benchmark.py
@@ -31,7 +31,6 @@
import json
import os
import time
from typing import TypeVar

import jax
import jax.numpy as jnp
@@ -47,17 +46,16 @@
from coreax.metrics import KSD, MMD
from coreax.solvers import (
KernelHerding,
KernelThinning,
RandomSample,
RPCholesky,
Solver,
SteinThinning,
)
from coreax.weights import MMDWeightsOptimiser

_Solver = TypeVar("_Solver", bound=Solver)


def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKernel:
def setup_kernel(x: jax.Array, random_seed: int = 45) -> SquaredExponentialKernel:
"""
Set up a squared exponential kernel using the median heuristic.

@@ -102,20 +100,23 @@ def setup_solvers(
coreset_size: int,
sq_exp_kernel: SquaredExponentialKernel,
stein_kernel: SteinKernel,
delta: float,
random_seed: int = 45,
) -> list[tuple[str, _Solver]]:
) -> list[tuple[str, Solver]]:
"""
Set up and return a list of solver configurations for reducing a dataset.

:param coreset_size: The size of the coresets to be generated by the solvers.
:param sq_exp_kernel: A Squared Exponential kernel for KernelHerding and RPCholesky.
:param stein_kernel: A Stein kernel object used for the SteinThinning solver.
:param delta: The delta parameter for KernelThinning solver.
:param random_seed: An integer seed for the random number generator.

:return: A list of tuples, where each tuple contains the name of the solver
and the corresponding solver object.
"""
random_key = jax.random.PRNGKey(random_seed)
sqrt_kernel = sq_exp_kernel.get_sqrt_kernel(2)
return [
(
"KernelHerding",
@@ -141,11 +142,21 @@ def setup_solvers(
regularise=False,
),
),
(
"KernelThinning",
KernelThinning(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
random_key=random_key,
delta=delta,
sqrt_kernel=sqrt_kernel,
),
),
]


def compute_solver_metrics(
solver: _Solver,
solver: Solver,
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
@@ -188,7 +199,7 @@ def compute_solver_metrics(


def compute_metrics(
solvers: list[tuple[str, _Solver]],
solvers: list[tuple[str, Solver]],
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
@@ -264,7 +275,7 @@ def main() -> None: # pylint: disable=too-many-locals
aggregated_results[size][solver_name][metric].append(value)

# Average results across seeds
final_results = {"n_samples": n_samples}
final_results: dict = {"n_samples": n_samples}
for size, solvers in aggregated_results.items():
final_results[size] = {}
for solver_name, metrics in solvers.items():
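A minimal usage sketch (not part of the diff): with `sq_exp_kernel` and `stein_kernel` already constructed as elsewhere in blobs_benchmark.py, the new `delta` argument is now required and `KernelThinning` appears in the returned list. The `delta` value below is illustrative only.

solvers = setup_solvers(
    coreset_size=100,
    sq_exp_kernel=sq_exp_kernel,
    stein_kernel=stein_kernel,
    delta=0.01,  # failure-probability parameter passed through to KernelThinning
    random_seed=45,
)
for name, solver in solvers:
    # Names now include "KernelThinning" alongside KernelHerding, RandomSample,
    # RPCholesky and SteinThinning.
    print(name, type(solver).__name__)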
2 changes: 1 addition & 1 deletion benchmark/blobs_benchmark_visualiser.py
@@ -89,7 +89,7 @@ def plot_benchmarking_results(data):

# Adjust layout to avoid overlap
plt.subplots_adjust(hspace=15.0, wspace=1.0)
plt.tight_layout(pad=3.0, rect=[0, 0, 1, 0.96])
plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96))
plt.show()


7 changes: 5 additions & 2 deletions benchmark/david_benchmark.py
@@ -29,6 +29,7 @@
"""

import os
import sys
import time
from pathlib import Path
from typing import Optional
@@ -38,11 +39,14 @@
import matplotlib.pyplot as plt
import numpy as np
from jax import random
from mnist_benchmark import get_solver_name, initialise_solvers

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers
from coreax import Data
from examples.david_map_reduce_weighted import downsample_opencv

sys.path.append(str(Path(__file__).parent.parent))


MAX_8BIT = 255


@@ -65,7 +69,6 @@ def benchmark_coreset_algorithms(
"""
# Base directory of the current script
base_dir = os.path.dirname(os.path.abspath(__file__))

# Convert to absolute paths using os.path.join
if not in_path.is_absolute():
in_path = Path(os.path.join(base_dir, in_path))
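The absolute `benchmark.mnist_benchmark` import relies on the repository root being on `sys.path`. A minimal sketch of the mechanics assumed here (the paths and ordering are illustrative, not the file's exact layout):

import sys
from pathlib import Path

# Put the repository root (the parent of benchmark/) on sys.path so that
# absolute imports such as `benchmark.mnist_benchmark` and
# `examples.david_map_reduce_weighted` resolve when the script is run directly.
repo_root = Path(__file__).parent.parent
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

from benchmark.mnist_benchmark import get_solver_name, initialise_solvers  # noqa: E402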
88 changes: 72 additions & 16 deletions benchmark/mnist_benchmark.py
@@ -52,6 +52,7 @@
import umap
from flax import linen as nn
from flax.training import train_state
from jaxtyping import Array, Float
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

@@ -60,12 +61,14 @@
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import (
KernelHerding,
KernelThinning,
MapReduce,
RandomSample,
RPCholesky,
Solver,
SteinThinning,
)
from coreax.util import KeyArrayLike


# Convert PyTorch dataset to JAX arrays
@@ -77,7 +80,8 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr
:return: Tuple of JAX arrays (data, targets).
"""
# Load all data in one batch
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))
# pyright is wrong here; a map-style Dataset object does have a __len__ method
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data)) # type: ignore
# Grab the first batch, which is all data
_data, _targets = next(iter(data_loader))
# Convert to NumPy first, then JAX array
@@ -149,8 +153,8 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
class TrainState(train_state.TrainState):
"""Custom train state with batch statistics and dropout RNG."""

batch_stats: Optional[dict[str, jnp.ndarray]] = None
dropout_rng: Optional[jnp.ndarray] = None
batch_stats: Optional[dict[str, jnp.ndarray]]
dropout_rng: KeyArrayLike


class Metrics(NamedTuple):
@@ -161,7 +165,7 @@ class Metrics(NamedTuple):


def create_train_state(
rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
rng: KeyArrayLike, _model: nn.Module, learning_rate: float, weight_decay: float
) -> TrainState:
"""
Create and initialise the train state.
@@ -323,7 +327,7 @@ def train_and_evaluate(
train_set: DataSet,
test_set: DataSet,
_model: nn.Module,
rng: jnp.ndarray,
rng: KeyArrayLike,
config: dict[str, Any],
) -> dict[str, float]:
"""
@@ -426,8 +430,32 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr
return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax


def calculate_delta(n: int) -> Float[Array, "1"]:
"""
Calculate the delta parameter for kernel thinning.

The function evaluates the following cases:
1. If `jnp.log(n)` is positive:
- Further evaluates `jnp.log(jnp.log(n))`.
* If this is also positive, returns `1 / (n * jnp.log(jnp.log(n)))`.
* Otherwise, returns `1 / (n * jnp.log(n))`.
2. If `jnp.log(n)` is not positive:
- Returns `1 / n`.

:param n: The size of the dataset we wish to reduce.
:return: The calculated delta value based on the described conditions.
"""
log_n = jnp.log(n)
if log_n > 0:
log_log_n = jnp.log(log_n)
if log_log_n > 0:
return 1 / (n * log_log_n)
return 1 / (n * log_n)
return jnp.array(1 / n)


def initialise_solvers(
train_data_umap: Data, key: jax.random.PRNGKey
train_data_umap: Data, key: KeyArrayLike
) -> list[Callable[[int], Solver]]:
"""
Initialise and return a list of solvers for various coreset algorithms.
@@ -449,8 +477,30 @@ def initialise_solvers(
random_seed = 45
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(train_data_umap[idx])
length_scale = median_heuristic(jnp.asarray(train_data_umap[idx]))
kernel = SquaredExponentialKernel(length_scale=length_scale)
sqrt_kernel = kernel.get_sqrt_kernel(16)

def _get_thinning_solver(_size: int) -> MapReduce:
"""
Set up KernelThinning to use ``MapReduce``.

Create a KernelThinning solver with the specified coreset size and wrap
it in a MapReduce object for reducing a large dataset like the MNIST
dataset.

:param _size: The size of the coreset to be generated.
:return: MapReduce solver with KernelThinning as the base solver.
"""
thinning_solver = KernelThinning(
coreset_size=_size,
kernel=kernel,
random_key=key,
delta=calculate_delta(num_data_points).item(),
sqrt_kernel=sqrt_kernel,
)

return MapReduce(thinning_solver, leaf_size=15_000)

def _get_herding_solver(_size: int) -> MapReduce:
"""
@@ -461,10 +511,10 @@ def _get_herding_solver(_size: int) -> MapReduce:
MNIST dataset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the MapReduce solver.
:return: MapReduce solver with KernelHerding as the base solver.
"""
herding_solver = KernelHerding(_size, kernel)
return MapReduce(herding_solver, leaf_size=3 * _size)
return MapReduce(herding_solver, leaf_size=15_000)

def _get_stein_solver(_size: int) -> MapReduce:
"""
@@ -475,25 +525,25 @@ def _get_stein_solver(_size: int) -> MapReduce:
a subset of the dataset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the MapReduce solver.
:return: MapReduce solver with SteinThinning as the base solver.
"""
# Generate small dataset for ScoreMatching for Stein Kernel

score_function = KernelDensityMatching(length_scale=length_scale).match(
score_function = KernelDensityMatching(length_scale=length_scale.item()).match(
train_data_umap[idx]
)
stein_kernel = SteinKernel(kernel, score_function)
stein_solver = SteinThinning(
coreset_size=_size, kernel=stein_kernel, regularise=False
)
return MapReduce(stein_solver, leaf_size=3 * _size)
return MapReduce(stein_solver, leaf_size=15_000)

def _get_random_solver(_size: int) -> RandomSample:
"""
Set up Random Sampling to generate a coreset.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the RandomSample solver.
:return: A RandomSample solver.
"""
random_solver = RandomSample(_size, key)
return random_solver
@@ -503,17 +553,23 @@ def _get_rp_solver(_size: int) -> RPCholesky:
Set up Randomised Cholesky solver.

:param _size: The size of the coreset to be generated.
:return: A tuple containing the solver name and the RPCholesky solver.
:return: An RPCholesky solver.
"""
rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key)
return rp_solver

return [_get_random_solver, _get_rp_solver, _get_herding_solver, _get_stein_solver]
return [
_get_random_solver,
_get_rp_solver,
_get_herding_solver,
_get_stein_solver,
_get_thinning_solver,
]


def train_model(
data_bundle: dict[str, jnp.ndarray],
key: jax.random.PRNGKey,
key: KeyArrayLike,
config: dict[str, Union[int, float]],
) -> dict[str, float]:
"""
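Taken together, `calculate_delta` and the new `_get_thinning_solver` factory extend the list returned by `initialise_solvers`. A minimal sketch of how the pieces compose, assuming `train_data_umap` and `key` are prepared as in the benchmark's main flow; the sizes are illustrative.

n = 60_000  # roughly the MNIST training set size
delta = calculate_delta(n).item()  # log(log(n)) > 0 here, so this is 1 / (n * log(log(n)))

solver_factories = initialise_solvers(train_data_umap, key)
for factory in solver_factories:
    solver = factory(1_000)  # build each solver for a 1,000-point coreset
    # RandomSample and RPCholesky come back directly; KernelHerding, SteinThinning
    # and KernelThinning are wrapped in MapReduce with leaf_size=15_000.
    print(type(solver).__name__)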