diff --git a/tests/regression_test.py b/tests/regression_test.py
index fd4a2d7..f13b403 100644
--- a/tests/regression_test.py
+++ b/tests/regression_test.py
@@ -59,6 +59,40 @@ def myprint(instr):
     print(f"{datetime.datetime.now()} - {instr}")
 
 
+def download_file(url, filename, timeout=60):
+    """Download ``url`` and write the response body to ``filename``.
+
+    Missing parent directories of ``filename`` are created before writing.
+
+    Args:
+        url: Address to fetch with an HTTP GET request.
+        filename: Destination path for the downloaded bytes.
+        timeout: Seconds to wait for the server before giving up.
+
+    Returns:
+        ``filename``, unchanged.
+
+    Raises:
+        RuntimeError: If the server answers with a status other than 200.
+        requests.RequestException: If the request itself fails or times out.
+    """
+    response = requests.get(url, timeout=timeout)
+
+    # Fail loudly: returning a path that was never written would only
+    # surface later as a confusing FileNotFoundError in the caller.
+    if response.status_code != 200:
+        raise RuntimeError(f"Failed to download file. Status code: {response.status_code}")
+
+    target_dir = os.path.dirname(filename)
+    if target_dir:
+        os.makedirs(target_dir, exist_ok=True)
+
+    with open(filename, 'wb') as file:
+        file.write(response.content)
+
+    return filename
+
+
 def download_and_extract_zip(zip_file_name, cache_dir, url, demo_data_dir):
     # Check and construct the full path of the zip file
     zip_file_path = os.path.join(cache_dir, zip_file_name)
@@ -375,19 +409,28 @@ def test_calibrator(self):
 
         default_baseline_path = os.path.join(PROJECT_DIR, "tests", "sybil_ensemble_v1.4.0_calibrations.json")
         baseline_path = os.environ.get("SYBIL_TEST_BASELINE_PATH", default_baseline_path)
-        new_cal_dict_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", "~/.sybil/sybil_ensemble_simple_calibrator.json")
-        new_cal_dict_path = os.path.expanduser(new_cal_dict_path)
-        raw_new_cal_dict = json.load(open(new_cal_dict_path, "r"))
-        new_cal_dict = {}
-        for key, val in raw_new_cal_dict.items():
-            new_cal_dict[key] = sybil.models.calibrator.SimpleClassifierGroup.from_json(val)
+        if not os.path.exists(baseline_path) and baseline_path == default_baseline_path:
+            os.makedirs(os.path.dirname(default_baseline_path), exist_ok=True)
+            reference_calibrations_url = "https://www.dropbox.com/scl/fi/2fx6ukmozia7y3u8mie97/sybil_ensemble_v1.4.0_calibrations.json?rlkey=tquids9qo4mkkuf315nqdq0o7&dl=1"
+            download_file(reference_calibrations_url, default_baseline_path)
+
+        default_cal_dict_path = os.path.expanduser("~/.sybil/sybil_ensemble_simple_calibrator.json")
+        compare_calibrator_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", default_cal_dict_path)
+        compare_calibrator_path = os.path.expanduser(compare_calibrator_path)
+        if not os.path.exists(compare_calibrator_path) and compare_calibrator_path == default_cal_dict_path:
+            test_model = Sybil("sybil_ensemble")
+
+        raw_calibrator_dict = json.load(open(compare_calibrator_path, "r"))
+        new_calibrator_dict = {}
+        for key, val in raw_calibrator_dict.items():
+            new_calibrator_dict[key] = sybil.models.calibrator.SimpleClassifierGroup.from_json(val)
 
         baseline_preds = json.load(open(baseline_path, "r"))
         test_probs = np.array(baseline_preds["x"]).reshape(-1, 1)
         year_keys = [key for key in baseline_preds.keys() if key.startswith("Year")]
         for year_key in year_keys:
             baseline_scores = np.array(baseline_preds[year_key]).flatten()
-            new_cal = new_cal_dict[year_key]
+            new_cal = new_calibrator_dict[year_key]
             new_scores = new_cal.predict_proba(test_probs).flatten()
             self.assertTrue(np.allclose(baseline_scores, new_scores, atol=1e-10), f"Calibration mismatch for {year_key}")
 