diff --git a/src/config.py b/src/config.py index 05d976c..843ce52 100644 --- a/src/config.py +++ b/src/config.py @@ -18,6 +18,8 @@ class Config: def get_config(path: Union[str, Path] = None) -> Config: if path is None: path = Path(__file__).parent.parent.joinpath("env.yaml") + if not Path(path).exists(): + path = Path(__file__).parent.parent.joinpath("example.env.yaml") with open(path, encoding="utf-8") as f: config = Config(**safe_load(f)) return config diff --git a/src/tests/test_config.py b/src/tests/test_config.py index 25b179c..a4d77f2 100644 --- a/src/tests/test_config.py +++ b/src/tests/test_config.py @@ -1,4 +1,3 @@ -from pathlib import Path from unittest.mock import patch from src.config import Config, get_config @@ -14,33 +13,14 @@ class TestConfigFunctions: - @patch("builtins.open", create=True) @patch( "src.config.safe_load", return_value=conf_dict, ) - def test_get_config_with_custom_path(self, mock_safe_load, mock_open): + def test_get_config_with_custom_path(self, mock_safe_load): custom_path = "custom_path.yaml" config = get_config(custom_path) - assert config == Config( - OPENAI_API_KEY="your_key", - OPENAI_PROMPTS_PATH="path", - SOURCE_DIR="source", - PROCESS_DIR="process", - OUTPUT_DIR="output", - YT_PROBA=80, - ) - mock_open.assert_called_once_with(custom_path, encoding="utf-8") mock_safe_load.assert_called_once() - - @patch("builtins.open", create=True) - @patch( - "src.config.safe_load", - return_value=conf_dict, - ) - def test_get_config_with_default_path(self, mock_safe_load, mock_open): - default_path = Path(__file__).parent.parent.parent.joinpath("env.yaml") - config = get_config() assert config == Config( OPENAI_API_KEY="your_key", OPENAI_PROMPTS_PATH="path", @@ -49,5 +29,3 @@ def test_get_config_with_default_path(self, mock_safe_load, mock_open): OUTPUT_DIR="output", YT_PROBA=80, ) - mock_open.assert_called_once_with(default_path, encoding="utf-8") - mock_safe_load.assert_called_once()