diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 1fa5cdc269a..ca90580ba3f 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -1,7 +1,9 @@ import pytest from unittest import mock -from dbt.context.providers import BaseResolver +from dbt.adapters.base import BaseRelation +from dbt.context.providers import BaseResolver, RuntimeRefResolver, RuntimeSourceResolver +from dbt.contracts.graph.unparsed import Quoting class TestBaseResolver: @@ -26,3 +28,77 @@ def test_resolve_limit(self, resolver, empty, expected_resolve_limit): resolver.config.args.EMPTY = empty assert resolver.resolve_limit == expected_resolve_limit + + +class TestRuntimeRefResolver: + @pytest.fixture + def resolver(self): + mock_db_wrapper = mock.Mock() + mock_db_wrapper.Relation = BaseRelation + + return RuntimeRefResolver( + db_wrapper=mock_db_wrapper, + model=mock.Mock(), + config=mock.Mock(), + manifest=mock.Mock(), + ) + + @pytest.mark.parametrize( + "empty,is_ephemeral_model,expected_limit", + [ + (False, False, None), + (True, False, 0), + (False, True, None), + (True, True, 0), + ], + ) + def test_create_relation_with_empty(self, resolver, empty, is_ephemeral_model, expected_limit): + # setup resolver and input node + resolver.config.args.EMPTY = empty + mock_node = mock.Mock() + mock_node.database = "test" + mock_node.schema = "test" + mock_node.identifier = "test" + mock_node.alias = "test" + mock_node.is_ephemeral_model = is_ephemeral_model + + # create limited relation + with mock.patch("dbt.adapters.base.relation.ParsedNode", new=mock.Mock): + relation = resolver.create_relation(mock_node) + assert relation.limit == expected_limit + + +class TestRuntimeSourceResolver: + @pytest.fixture + def resolver(self): + mock_db_wrapper = mock.Mock() + mock_db_wrapper.Relation = BaseRelation + + return RuntimeSourceResolver( + db_wrapper=mock_db_wrapper, + model=mock.Mock(), + config=mock.Mock(), + manifest=mock.Mock(), + ) + + @pytest.mark.parametrize( + "empty,expected_limit", + [ + (False, None), + (True, 0), + ], + ) + def test_create_relation_with_empty(self, resolver, empty, expected_limit): + # setup resolver and input source + resolver.config.args.EMPTY = empty + + mock_source = mock.Mock() + mock_source.database = "test" + mock_source.schema = "test" + mock_source.identifier = "test" + mock_source.quoting = Quoting() + resolver.manifest.resolve_source.return_value = mock_source + + # create limited relation + relation = resolver.resolve("test", "test") + assert relation.limit == expected_limit