Skip to content

Commit

Permalink
avoid using union
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
  • Loading branch information
bangtianliu committed Jan 20, 2025
1 parent 4d96e98 commit 05e84ca
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
24 changes: 11 additions & 13 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tqdm import tqdm
import hashlib
from dataclasses import dataclass, field
from typing import Type, Optional, Callable, Iterable, Any, Union
from typing import Type, Optional, Callable, Iterable, Any
from abc import ABC, abstractmethod
import iree.runtime as ireert # type: ignore
import iree.compiler as ireec # type: ignore
Expand Down Expand Up @@ -876,14 +876,13 @@ def is_valid_for_device(self, device_id: str) -> bool:

def get_candidates_ordered_by_speedup(
self, candidate_results: list[BenchmarkResult]
) -> Union[list[BenchmarkResult], list[tuple[BenchmarkResult, float]]]:
) -> list[tuple[BenchmarkResult, float]]:
"""
returns:
- `list[BenchmarkResult]` sorted by runtime if no valid baselines exist.
- `list[tuple[BenchmarkResult, float]]` sorted by speedup when baselines are available.
Returns a list of tuples (BenchmarkResult, speedup) sorted in ascending order based on speedup
or raw runtime.
If no valid baseline times are available across all devices, candidates are sorted
and returned based on their raw runtime in ascending order.
If no valid baseline times are available across all devices, candidates are sorted based on
their raw runtime. A placeholder speedup value of 1.0 is assigned to each candidate.
If valid baseline times exist, speedup is defined as the ratio of the candidate's runtime to
the average baseline time for the corresponding device as:
Expand All @@ -896,7 +895,10 @@ def get_candidates_ordered_by_speedup(
if not self.is_valid():
logging.warning("No valid baseline times available.")
# Use the candidate time directly when no baselines are available.
return sorted(candidate_results, key=lambda candidate: candidate.time)
return sorted(
[(candidate, 1.0) for candidate in candidate_results],
key=lambda x: x[0].time,
)

# Calculate the fallback baseline as the average of all valid times across devices.
valid_baseline_times = [
Expand Down Expand Up @@ -1065,11 +1067,7 @@ def benchmark(
f"({percentage_of_baseline:.1f}% of baseline)"
)
else:
top_candiates = [
result if isinstance(result, BenchmarkResult) else result[0]
for result in top_candidates_with_speedup
]
for candidate in top_candiates:
for candidate, _ in top_candidates_with_speedup:
time_ms = candidate.time
candidate_id = candidate.candidate_id
top_candidate_ids.append(candidate_id)
Expand Down
17 changes: 10 additions & 7 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,14 @@ def test_baseline_result_handler_speedup():
]

handler = libtuner.BaselineResultHandler()
all_candidates = handler.get_candidates_ordered_by_speedup(candidates)
assert all_candidates == [
candidates[1],
candidates[0],
candidates[2],
all_candidates_with_speedup = handler.get_candidates_ordered_by_speedup(candidates)
assert all_candidates_with_speedup == [
(candidates[1], 1.0),
(candidates[0], 1.0),
(candidates[2], 1.0),
]
top_candidates_with_speedup = all_candidates_with_speedup[:2]
assert [candidate.candidate_id for candidate, _ in top_candidates_with_speedup] == [
6,
5,
]
top_candidates = all_candidates[:2]
assert [candidate.candidate_id for candidate in top_candidates] == [6, 5]

0 comments on commit 05e84ca

Please sign in to comment.