def _resolve_default_cache_dir() -> Path:
    """Return the default root directory for the OpenML cache.

    Resolution order:

    1. ``$OPENML_CACHE_DIR`` when the variable is set.
    2. ``~/.openml`` on every platform other than Linux.
    3. On Linux, ``$XDG_CACHE_HOME/openml``; ``~/.cache/openml`` when
       ``XDG_CACHE_HOME`` is unset or empty.

    Backwards compatibility: ``$XDG_CACHE_HOME`` itself used to be the cache
    root, which left data under ``$XDG_CACHE_HOME/org/openml``. When that old
    directory exists and ``$XDG_CACHE_HOME/openml`` does not, the old root is
    returned and a warning is logged. The data is not moved for the user.

    Returns
    -------
    Path
        The cache root. A path taken from an environment variable is returned
        as given (not expanded or resolved).
    """
    user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR)
    if user_defined_cache_dir is not None:
        return Path(user_defined_cache_dir)

    if platform.system().lower() != "linux":
        return _user_path / ".openml"

    # The XDG spec treats an empty variable exactly like an unset one. Build the
    # fallback from the already-expanded home directory, like the branch above,
    # so callers never receive a path containing a literal "~".
    xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
    if not xdg_cache_home:
        return _user_path / ".cache" / "openml"

    xdg_root = Path(xdg_cache_home)

    # The proper location already exists: use it.
    cache_dir = xdg_root / "openml"
    if cache_dir.exists():
        return cache_dir

    # Heuristic for the old layout; without it there is nothing to preserve.
    legacy_cache_dir = xdg_root / "org" / "openml"
    if not legacy_cache_dir.exists():
        return cache_dir

    openml_logger.warning(
        "An old cache directory was found at '%s'. This directory is no longer used by "
        "OpenML-Python. To silence this warning you would need to delete the old cache "
        "directory. The cached files will then be located in '%s'.",
        xdg_root / "org",
        cache_dir,
    )
    # Keep using the old root so the existing cached data stays reachable.
    return xdg_root


def _handle_xdg_config_home_backwards_compatibility(
    xdg_home: str,
) -> Path:
    """Return the directory that holds the openml config file under ``xdg_home``.

    A previous bug placed the config file at ``<xdg_home>/config`` instead of
    ``<xdg_home>/openml/config``. To stay backwards compatible, an old file
    that parses as an openml config is copied to the new location once, and a
    warning is logged on every call until the old file is deleted. An already
    existing file at the new location is never overwritten.

    Parameters
    ----------
    xdg_home : str
        The value of ``XDG_CONFIG_HOME``.

    Returns
    -------
    Path
        ``<xdg_home>/openml`` normally. ``<xdg_home>`` itself when an old
        config file exists but could not be copied, so that appending
        ``"config"`` to the result still points at the old file.
    """
    xdg_root = Path(xdg_home)
    config_dir = xdg_root / "openml"

    # Only a regular file can be an old config; a directory called "config"
    # belongs to something else.
    backwards_compat_config_file = xdg_root / "config"
    if not backwards_compat_config_file.is_file():
        return config_dir

    # Heuristic to make sure the file is "ours": if parsing fails for any
    # reason, it is assumed to belong to something else and is ignored.
    try:
        _parse_config(backwards_compat_config_file)
    except Exception:  # noqa: BLE001
        return config_dir

    correct_config_location = config_dir / "config"
    if correct_config_location.exists():
        # Copying again would silently discard edits made to the new file.
        openml_logger.warning(
            "An openml configuration file was found at the old location "
            "at %s, but a configuration file already exists at the new "
            "location at %s and takes precedence. "
            "\nTo silence this warning please verify that the configuration file "
            "at %s is correct and delete the file at %s.",
            backwards_compat_config_file,
            correct_config_location,
            correct_config_location,
            backwards_compat_config_file,
        )
        return config_dir

    try:
        # The openml directory usually does not exist yet for affected users.
        config_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(backwards_compat_config_file, correct_config_location)
    except OSError as e:
        # The file is ours but could not be copied: keep using the old one.
        openml_logger.warning(
            "While attempting to perform a backwards compatible fix, we "
            "failed to copy the openml config file at %s to %s"
            "\n%s: %s"
            "\n\nTo silence this warning, please copy the file "
            "to the new location and delete the old file at %s.",
            backwards_compat_config_file,
            correct_config_location,
            type(e).__name__,
            e,
            backwards_compat_config_file,
        )
        return xdg_root

    openml_logger.warning(
        "An openml configuration file was found at the old location "
        "at %s. We have copied it to the new location at %s. "
        "\nTo silence this warning please verify that the configuration file "
        "at %s is correct and delete the file at %s.",
        backwards_compat_config_file,
        correct_config_location,
        correct_config_location,
        backwards_compat_config_file,
    )
    return config_dir


def determine_config_file_path() -> Path:
    """Return the absolute path of the openml configuration file.

    On Linux the file is ``config`` inside the directory chosen by
    ``_handle_xdg_config_home_backwards_compatibility`` when
    ``XDG_CONFIG_HOME`` is set to a non-empty value, and inside
    ``~/.config/openml`` otherwise. On every other platform it is
    ``~/.openml/config``.

    Returns
    -------
    Path
        The user-expanded, resolved path of the config file. The file is not
        required to exist.
    """
    if platform.system().lower() == "linux":
        # The XDG spec treats an empty variable exactly like an unset one.
        xdg_home = os.environ.get("XDG_CONFIG_HOME")
        if xdg_home:
            config_dir = _handle_xdg_config_home_backwards_compatibility(xdg_home)
        else:
            config_dir = Path("~", ".config", "openml")
    else:
        config_dir = Path("~") / ".openml"

    config_dir = Path(config_dir).expanduser().resolve()
    return config_dir / "config"
@contextmanager
def safe_environ_patcher(key: str, value: Any) -> Iterator[None]:
    """Context manager to temporarily set an environment variable.

    The previous state of ``key`` is restored when the block exits, whether
    it finishes normally or raises: the old value is put back, or the
    variable is removed if it was not set before.

    Parameters
    ----------
    key : str
        Name of the environment variable.
    value : Any
        Assigned to ``os.environ[key]`` unchanged, so it has to be a ``str``.
    """
    _prev = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if _prev is None:
            # The managed block may already have removed the variable itself;
            # a KeyError here would hide whatever the block raised.
            os.environ.pop(key, None)
        else:
            os.environ[key] = _prev
def test_configuration_loads_booleans(tmp_path):
    """Check that ``true``/``false`` in a config file are parsed into real booleans."""
    config_path = tmp_path / "config"
    config_path.write_text("avoid_duplicate_runs=true\nshow_progress=false")

    # Pass the file itself, as the config module does when it parses a config.
    read_config = openml.config._parse_config(config_path)

    # Identity checks on purpose: `True == 1` holds, so an equality check
    # would also accept a truthy/falsy value of another type.
    assert read_config["avoid_duplicate_runs"] is True
    assert read_config["show_progress"] is False


def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
    """Check that ``OPENML_CACHE_DIR`` overrides the configured cache directory."""
    expected_path = tmp_path / "test-cache"
    previous_root = openml.config._root_cache_directory

    try:
        with safe_environ_patcher("OPENML_CACHE_DIR", str(expected_path)):
            openml.config._setup()
            assert openml.config._root_cache_directory == expected_path
            assert openml.config.get_cache_directory() == str(
                expected_path / "org" / "openml" / "www"
            )
    finally:
        # _setup() rebinds the module-level cache root; put the old one back so
        # later tests do not run against this test's temporary directory.
        openml.config._root_cache_directory = previous_root