Skip to content

Commit

Permalink
Download reference calibrator and calibrations in unit test if they d…
Browse files Browse the repository at this point in the history
…on't already exist.
  • Loading branch information
jsilter committed Jul 31, 2024
1 parent fa292a6 commit 552eda0
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def myprint(instr):
print(f"{datetime.datetime.now()} - {instr}")


def download_file(url, filename, timeout=60):
    """Download ``url`` with an HTTP GET and save the body to ``filename``.

    The parent directory of ``filename`` is created when it is missing.
    This is a best-effort helper: if the server answers with any status
    other than 200, a message is printed and no file is written, yet
    ``filename`` is still returned. Callers must therefore not assume the
    file exists afterwards.

    Args:
        url: Address to fetch.
        filename: Destination path for the downloaded bytes.
        timeout: Seconds ``requests`` waits on the server before giving up.
            Defaults to 60.

    Returns:
        ``filename``, unchanged.

    Raises:
        Whatever ``requests.get`` raises on connection failure or timeout;
        these are not caught here.
    """
    # Without an explicit timeout, requests can wait forever on an
    # unresponsive server and stall the whole test run.
    response = requests.get(url, timeout=timeout)

    target_dir = os.path.dirname(filename)
    if target_dir:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(target_dir, exist_ok=True)

    # Check if the request was successful
    if response.status_code == 200:
        with open(filename, 'wb') as file:
            file.write(response.content)
    else:
        print(f"Failed to download file. Status code: {response.status_code}")

    return filename


def download_and_extract_zip(zip_file_name, cache_dir, url, demo_data_dir):
# Check and construct the full path of the zip file
zip_file_path = os.path.join(cache_dir, zip_file_name)
Expand Down Expand Up @@ -375,19 +392,28 @@ def test_calibrator(self):
default_baseline_path = os.path.join(PROJECT_DIR, "tests", "sybil_ensemble_v1.4.0_calibrations.json")
baseline_path = os.environ.get("SYBIL_TEST_BASELINE_PATH", default_baseline_path)

new_cal_dict_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", "~/.sybil/sybil_ensemble_simple_calibrator.json")
new_cal_dict_path = os.path.expanduser(new_cal_dict_path)
raw_new_cal_dict = json.load(open(new_cal_dict_path, "r"))
new_cal_dict = {}
for key, val in raw_new_cal_dict.items():
new_cal_dict[key] = sybil.models.calibrator.SimpleClassifierGroup.from_json(val)
if not os.path.exists(baseline_path) and baseline_path == default_baseline_path:
os.makedirs(os.path.dirname(default_baseline_path), exist_ok=True)
reference_calibrations_url = "https://www.dropbox.com/scl/fi/2fx6ukmozia7y3u8mie97/sybil_ensemble_v1.4.0_calibrations.json?rlkey=tquids9qo4mkkuf315nqdq0o7&dl=1"
download_file(reference_calibrations_url, default_baseline_path)

default_cal_dict_path = os.path.expanduser("~/.sybil/sybil_ensemble_simple_calibrator.json")
compare_calibrator_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", default_cal_dict_path)
compare_calibrator_path = os.path.expanduser(compare_calibrator_path)
if not os.path.exists(compare_calibrator_path) and compare_calibrator_path == default_cal_dict_path:
test_model = Sybil("sybil_ensemble")

raw_calibrator_dict = json.load(open(compare_calibrator_path, "r"))
new_calibrator_dict = {}
for key, val in raw_calibrator_dict.items():
new_calibrator_dict[key] = sybil.models.calibrator.SimpleClassifierGroup.from_json(val)

baseline_preds = json.load(open(baseline_path, "r"))
test_probs = np.array(baseline_preds["x"]).reshape(-1, 1)
year_keys = [key for key in baseline_preds.keys() if key.startswith("Year")]
for year_key in year_keys:
baseline_scores = np.array(baseline_preds[year_key]).flatten()
new_cal = new_cal_dict[year_key]
new_cal = new_calibrator_dict[year_key]
new_scores = new_cal.predict_proba(test_probs).flatten()

self.assertTrue(np.allclose(baseline_scores, new_scores, atol=1e-10), f"Calibration mismatch for {year_key}")
Expand Down

0 comments on commit 552eda0

Please sign in to comment.