script_detect_bad_h5.py
#!/usr/bin/env python3
# coding: utf-8
"""
- Checks if the h5 files have labels.
- Checks if the h5 files have inputs.
- Checks if all the h5 files have the same inputs.
- Checks if all the h5 files have the same labels.
"""
print("Doing imports.")
# Stdlib
from collections import Counter
import math
from pathlib import Path
import itertools
import json
import shlex
import subprocess
import time
import re
# Third Party
from beartype import beartype
from beartype.typing import *
import h5py # type: ignore[import]
import fire # type: ignore[import]
import more_itertools
import numpy as np
import pretty_traceback # type: ignore[import]
import rich
import torch
from tqdm import tqdm # type: ignore[import]
# First Party
import general_utils
print("Done with imports.")
pretty_traceback.install()
SCRIPT_DIR = Path(__file__).absolute().parent
H5_INPUT_IDS_KEY = "input_samples"
H5_LABEL_IDS_KEY = "label_ids"
H5_PREDICTIONS_KEY = "predictions"
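# Assumed layout of each predictions.h5 file, inferred from the checks below
# (a sketch, not an authoritative schema):
#   input_samples: input token ids, first axis indexed by epoch
#   label_ids:     label token ids, first axis indexed by epoch
#   predictions:   predicted token ids, shape (num_epochs, num_samples, seq_len),
#                  with 0 used as padding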
def length_stats(h5):
    # Prediction lengths: for each (epoch, sample), count the non-zero
    # (non-padding) tokens along the sequence axis. The running max of the
    # cumulative sum of a boolean mask equals its total count.
    lengths = (h5[H5_PREDICTIONS_KEY][:] != 0).cumsum(axis=2).max(axis=2)
    assert lengths.shape == h5[H5_PREDICTIONS_KEY].shape[:2]
    return lengths.mean(), lengths.std(), lengths.max()

def detect_bads(h5_paths: Sequence[Union[Path, str]], num_epochs: int) -> Set[Path]:
    h5_paths = cast(Sequence[Path], [Path(h5_path) for h5_path in h5_paths])
    files = [h5py.File(path, "r") for path in h5_paths]

    # Every file must share the same inputs and labels over the first
    # `num_epochs` epochs as the first file.
    all_inputs_are_the_same = all([
        np.all(files[0][H5_INPUT_IDS_KEY][:num_epochs] ==
               files[i][H5_INPUT_IDS_KEY][:num_epochs])
        for i in range(len(files))
    ])
    all_labels_are_the_same = all([
        np.all(files[0][H5_LABEL_IDS_KEY][:num_epochs] ==
               files[i][H5_LABEL_IDS_KEY][:num_epochs])
        for i in range(len(files))
    ])
    assert all_inputs_are_the_same
    rich.print("[bold green]All files had the same inputs.")
    assert all_labels_are_the_same
    rich.print("[bold green]All files had the same labels.")
    print()

    rich.print("[bold]Doing prediction length stats.")
    means = []
    stds = []
    maxes = []
    for file, path in more_itertools.zip_equal(tqdm(files), h5_paths):
        size = path.stat().st_size
        mean, std, max_ = length_stats(file)
        means.append(mean)
        stds.append(std)
        maxes.append(max_)
        rich.print(f" - {general_utils.shorten_path(path)}")
        rich.print(f" - {general_utils.to_human_size(size)}")
        rich.print(f" - {file[H5_PREDICTIONS_KEY].shape}")
        rich.print(f" - {mean = :.1f}")
        rich.print(f" - {std = :.1f}")
        rich.print(f" - {max_ = }")
        print()
        # The longest prediction should span the full sequence axis.
        assert max_ == file[H5_PREDICTIONS_KEY].shape[2]

    print()
    rich.print("[bold]Means:")
    rich.print(means)
    print()
    rich.print("[bold]Stds:")
    rich.print(stds)
    print()
    rich.print("[bold]Maxes:")
    rich.print(maxes)

    for file in files:
        file.close()

    # The asserts above abort on the first mismatch, so reaching this point
    # means no bad file was detected; return an empty set so the caller's
    # set arithmetic works.
    return set()

@beartype
def main(
    directory: Union[Path, str] = SCRIPT_DIR / "log_results" / "oracle",
    max_epochs: int = 60,
):
    general_utils.check_and_print_args(locals().copy(), main)
    directory = Path(directory)
    assert directory.exists(), directory

    h5_paths = general_utils.sort_iterable_text(list(directory.glob("**/predictions.h5")))
    print()
    rich.print(f"[bold]All paths: [/bold]({len(h5_paths)})")
    general_utils.print_list(h5_paths)
    print()
    assert h5_paths

    bad_ones = detect_bads(h5_paths, num_epochs=max_epochs)
    good_ones = set(h5_paths) - bad_ones
    print()
    rich.print(f"[bold]Good ones: [/bold]({len(good_ones)}/{len(h5_paths)})")
    # general_utils.print_list(general_utils.sort_iterable_text(good_ones))
    print()
    rich.print(f"[bold]Bad ones: [/bold]({len(bad_ones)}/{len(h5_paths)})")
    general_utils.print_list(general_utils.sort_iterable_text(bad_ones))
    print()

    # Tally the dataset keys across all files and flag the files that are
    # missing the label ids dataset.
    dont_have_label_ids = set()
    keys = []
    for path in h5_paths:
        with h5py.File(path, "r") as file:
            keys.extend(file.keys())
            if H5_LABEL_IDS_KEY not in file:
                dont_have_label_ids.add(path)

    rich.print("[bold]Keys:")
    general_utils.print_dict(dict(Counter(keys).items()))
    print()
    rich.print(
        f"[bold]Don't have label_ids: [/bold]"
        f"({len(dont_have_label_ids)}/{len(h5_paths)})"
    )
    general_utils.print_list(general_utils.sort_iterable_text(dont_have_label_ids))
    print()

if __name__ == "__main__":
    fire.Fire(main)