From 08dc6aa12e0f00673bc5230cad3defd45c511f0f Mon Sep 17 00:00:00 2001 From: "Mo (laptop)" Date: Mon, 13 Nov 2023 00:43:07 +0500 Subject: [PATCH] Switch to example.env.yaml if no env.yaml Config test tweak --- src/config.py | 2 ++ src/tests/test_config.py | 24 +----------------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/config.py b/src/config.py index 05d976c..843ce52 100644 --- a/src/config.py +++ b/src/config.py @@ -18,6 +18,8 @@ class Config: def get_config(path: Union[str, Path] = None) -> Config: if path is None: path = Path(__file__).parent.parent.joinpath("env.yaml") + if not Path(path).exists(): + path = Path(__file__).parent.parent.joinpath("example.env.yaml") with open(path, encoding="utf-8") as f: config = Config(**safe_load(f)) return config diff --git a/src/tests/test_config.py b/src/tests/test_config.py index 25b179c..a4d77f2 100644 --- a/src/tests/test_config.py +++ b/src/tests/test_config.py @@ -1,4 +1,3 @@ -from pathlib import Path from unittest.mock import patch from src.config import Config, get_config @@ -14,33 +13,14 @@ class TestConfigFunctions: - @patch("builtins.open", create=True) @patch( "src.config.safe_load", return_value=conf_dict, ) - def test_get_config_with_custom_path(self, mock_safe_load, mock_open): + def test_get_config_with_custom_path(self, mock_safe_load): custom_path = "custom_path.yaml" config = get_config(custom_path) - assert config == Config( - OPENAI_API_KEY="your_key", - OPENAI_PROMPTS_PATH="path", - SOURCE_DIR="source", - PROCESS_DIR="process", - OUTPUT_DIR="output", - YT_PROBA=80, - ) - mock_open.assert_called_once_with(custom_path, encoding="utf-8") mock_safe_load.assert_called_once() - - @patch("builtins.open", create=True) - @patch( - "src.config.safe_load", - return_value=conf_dict, - ) - def test_get_config_with_default_path(self, mock_safe_load, mock_open): - default_path = Path(__file__).parent.parent.parent.joinpath("env.yaml") - config = get_config() assert config == Config( OPENAI_API_KEY="your_key", OPENAI_PROMPTS_PATH="path", @@ -49,5 +29,3 @@ def test_get_config_with_default_path(self, mock_safe_load, mock_open): OUTPUT_DIR="output", YT_PROBA=80, ) - mock_open.assert_called_once_with(default_path, encoding="utf-8") - mock_safe_load.assert_called_once()