From 7494e325ba036d792b01319abf8ae858d6b2eacd Mon Sep 17 00:00:00 2001
From: Sunish Sheth
Date: Thu, 2 May 2024 13:57:02 -0700
Subject: [PATCH] [MLflow] Updating langchain databricks_dependency to use
 resources format as well (#11869)

Signed-off-by: Sunish Sheth
---
 mlflow/langchain/__init__.py                  |  7 +-
 mlflow/langchain/databricks_dependencies.py   | 58 ++++++++-----
 .../_unity_catalog/registry/rest_store.py     | 57 ++++++++----
 ...gchain_databricks_dependency_extraction.py | 31 +++++--
 .../langchain/test_langchain_model_export.py  | 84 +++++++++++++++---
 .../test_unity_catalog_rest_store.py          | 87 +++++++++++++++++++
 6 files changed, 266 insertions(+), 58 deletions(-)

diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py
index 6f35f15365dbb..24bd11dd82566 100644
--- a/mlflow/langchain/__init__.py
+++ b/mlflow/langchain/__init__.py
@@ -57,6 +57,7 @@
 from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info
 from mlflow.models.model import MLMODEL_FILE_NAME
 from mlflow.models.model_config import _set_model_config
+from mlflow.models.resources import _ResourceBuilder
 from mlflow.models.signature import _infer_signature_from_input_example
 from mlflow.models.utils import _convert_llm_input_data, _save_example
 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
@@ -367,8 +368,12 @@ def load_retriever(persist_directory):
         )

     if Version(langchain.__version__) >= Version("0.0.311"):
-        if databricks_dependency := _detect_databricks_dependencies(lc_model):
+        (databricks_dependency, databricks_resources) = _detect_databricks_dependencies(lc_model)
+        if databricks_dependency:
             flavor_conf[_DATABRICKS_DEPENDENCY_KEY] = databricks_dependency
+        if databricks_resources:
+            serialized_databricks_resources = _ResourceBuilder.from_resources(databricks_resources)
+            mlflow_model.resources = serialized_databricks_resources

     mlflow_model.add_flavor(
         FLAVOR_NAME,
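`_ResourceBuilder.from_resources` is what turns the typed objects collected during detection into the plain mapping stored on `mlflow_model.resources` and written to the MLmodel file. A minimal sketch of that round trip, with illustrative endpoint and index names; the per-key shapes follow the assertions in this patch's tests:

from mlflow.models.resources import (
    DatabricksServingEndpoint,
    DatabricksVectorSearchIndex,
    _ResourceBuilder,
)

detected = [
    DatabricksVectorSearchIndex(index_name="catalog.schema.my_index"),
    DatabricksServingEndpoint(endpoint_name="my-embedding-endpoint"),
]
serialized = _ResourceBuilder.from_resources(detected)
# Per the test assertions below, the "databricks" section of the result
# groups resources by type:
# {"vector_search_index": [{"name": "catalog.schema.my_index"}],
#  "serving_endpoint": [{"name": "my-embedding-endpoint"}]}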
diff --git a/mlflow/langchain/databricks_dependencies.py b/mlflow/langchain/databricks_dependencies.py
index 12a173bffedcd..a1a2414caaf7e 100644
--- a/mlflow/langchain/databricks_dependencies.py
+++ b/mlflow/langchain/databricks_dependencies.py
@@ -2,6 +2,8 @@
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set
+from typing import Any, DefaultDict, Dict, List, Set, Tuple

+from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex, Resource
+
 _DATABRICKS_DEPENDENCY_KEY = "databricks_dependency"
 _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY = "databricks_vector_search_index_name"
 _DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY = "databricks_vector_search_endpoint_name"
@@ -14,7 +16,7 @@


 def _extract_databricks_dependencies_from_retriever(
-    retriever, dependency_dict: DefaultDict[str, List[Any]]
+    retriever, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
 ):
     try:
         from langchain.embeddings import DatabricksEmbeddings as LegacyDatabricksEmbeddings
@@ -38,20 +40,18 @@ def _extract_databricks_dependencies_from_retriever(
         index = vectorstore.index
         dependency_dict[_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY].append(index.name)
         dependency_dict[_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY].append(index.endpoint_name)
+        dependency_list.append(DatabricksVectorSearchIndex(index_name=index.name))
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=index.endpoint_name))

         embeddings = getattr(vectorstore, "embeddings", None)
         if isinstance(embeddings, (DatabricksEmbeddings, LegacyDatabricksEmbeddings)):
             dependency_dict[_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY].append(embeddings.endpoint)
-        elif (
-            callable(getattr(vectorstore, "_is_databricks_managed_embeddings", None))
-            and vectorstore._is_databricks_managed_embeddings()
-        ):
-            dependency_dict[_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY].append(
-                "_is_databricks_managed_embeddings"
-            )
+            dependency_list.append(DatabricksServingEndpoint(endpoint_name=embeddings.endpoint))


-def _extract_databricks_dependencies_from_llm(llm, dependency_dict: DefaultDict[str, List[Any]]):
+def _extract_databricks_dependencies_from_llm(
+    llm, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
+):
     try:
         from langchain.llms import Databricks as LegacyDatabricks
     except ImportError:
@@ -61,10 +61,11 @@ def _extract_databricks_dependencies_from_llm(

     if isinstance(llm, (LegacyDatabricks, Databricks)):
         dependency_dict[_DATABRICKS_LLM_ENDPOINT_NAME_KEY].append(llm.endpoint_name)
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=llm.endpoint_name))


 def _extract_databricks_dependencies_from_chat_model(
-    chat_model, dependency_dict: DefaultDict[str, List[Any]]
+    chat_model, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
 ):
     try:
         from langchain.chat_models import ChatDatabricks as LegacyChatDatabricks
@@ -77,6 +78,7 @@ def _extract_databricks_dependencies_from_chat_model(

     if isinstance(chat_model, (LegacyChatDatabricks, ChatDatabricks)):
         dependency_dict[_DATABRICKS_CHAT_ENDPOINT_NAME_KEY].append(chat_model.endpoint)
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=chat_model.endpoint))


 _LEGACY_MODEL_ATTR_SET = {
@@ -92,7 +94,9 @@
 }


-def _extract_dependency_dict_from_lc_model(lc_model, dependency_dict: DefaultDict[str, List[Any]]):
+def _extract_dependency_dict_from_lc_model(
+    lc_model, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
+):
     """
     This function contains the logic to examine a non-Runnable component of a langchain model.
     The logic here does not cover all legacy chains. If you need to support a custom chain,
@@ -102,16 +106,23 @@ def _extract_dependency_dict_from_lc_model(lc_model, dependency_dict: DefaultDic
         return

     # leaf node
-    _extract_databricks_dependencies_from_chat_model(lc_model, dependency_dict)
-    _extract_databricks_dependencies_from_retriever(lc_model, dependency_dict)
-    _extract_databricks_dependencies_from_llm(lc_model, dependency_dict)
+    _extract_databricks_dependencies_from_chat_model(lc_model, dependency_dict, dependency_list)
+    _extract_databricks_dependencies_from_retriever(lc_model, dependency_dict, dependency_list)
+    _extract_databricks_dependencies_from_llm(lc_model, dependency_dict, dependency_list)

     # recursively inspect legacy chain
     for attr_name in _LEGACY_MODEL_ATTR_SET:
-        _extract_dependency_dict_from_lc_model(getattr(lc_model, attr_name, None), dependency_dict)
+        _extract_dependency_dict_from_lc_model(
+            getattr(lc_model, attr_name, None), dependency_dict, dependency_list
+        )


-def _traverse_runnable(lc_model, dependency_dict: DefaultDict[str, List[Any]], visited: Set[str]):
+def _traverse_runnable(
+    lc_model,
+    dependency_dict: DefaultDict[str, List[Any]],
+    dependency_list: List[Resource],
+    visited: Set[str],
+):
     """
     This function contains the logic to traverse a
     langchain_core.runnables.RunnableSerializable object.
     It first inspects the current object using _extract_dependency_dict_from_lc_model
@@ -127,19 +138,21 @@ def _traverse_runnable(lc_model, dependency_dict: DefaultDict[str, List[Any]], v
     # Visit the current object
     visited.add(current_object_id)
-    _extract_dependency_dict_from_lc_model(lc_model, dependency_dict)
+    _extract_dependency_dict_from_lc_model(lc_model, dependency_dict, dependency_list)

     if isinstance(lc_model, Runnable):
         # Visit the returned graph
         for node in lc_model.get_graph().nodes.values():
-            _traverse_runnable(node.data, dependency_dict, visited)
+            _traverse_runnable(node.data, dependency_dict, dependency_list, visited)
     else:
         # No-op for non-runnable, if any
         pass
     return


-def _detect_databricks_dependencies(lc_model, log_errors_as_warnings=True) -> Dict[str, List[Any]]:
+def _detect_databricks_dependencies(
+    lc_model, log_errors_as_warnings=True
+) -> Tuple[Dict[str, List[Any]], List[Resource]]:
     """
-    Detects the databricks dependencies of a langchain model and returns a dictionary of
-    detected endpoint names and index names.
+    Detects the Databricks dependencies of a langchain model and returns a tuple of a
+    dictionary of detected endpoint and index names and a list of typed Resource objects.
@@ -162,8 +175,9 @@ def _detect_databricks_dependencies(lc_model, log_errors_as_warnings=True) -> Di
     """
     try:
         dependency_dict = defaultdict(list)
-        _traverse_runnable(lc_model, dependency_dict, set())
-        return dict(dependency_dict)
+        dependency_list = []
+        _traverse_runnable(lc_model, dependency_dict, dependency_list, set())
+        return (dict(dependency_dict), dependency_list)
     except Exception:
         if log_errors_as_warnings:
             _logger.warning(
@@ -171,5 +185,5 @@ def _detect_databricks_dependencies(lc_model, log_errors_as_warnings=True) -> Di
                 "Unable to detect Databricks dependencies. "
                 "Set logging level to DEBUG to see the full traceback."
             )
             _logger.debug("", exc_info=True)
-            return {}
+            return {}, []
         raise
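With this change `_detect_databricks_dependencies` reports the same findings twice: once in the legacy flavor-config dictionary and once as a flat list of typed resources filled in by the same traversal. A sketch of the new calling convention; `chain` stands for any langchain runnable and the commented values are illustrative:

from mlflow.langchain.databricks_dependencies import _detect_databricks_dependencies

def show_detected_dependencies(chain):
    # Hypothetical helper, for illustration only.
    dependency_dict, dependency_list = _detect_databricks_dependencies(chain)
    # dependency_dict, e.g.
    #   {"databricks_chat_endpoint_name": ["databricks-llama-2-70b-chat"]}
    # dependency_list, e.g.
    #   [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]
    return dependency_dict, dependency_list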
diff --git a/mlflow/store/_unity_catalog/registry/rest_store.py b/mlflow/store/_unity_catalog/registry/rest_store.py
index 4e1c623c33225..599926f1f56c5 100644
--- a/mlflow/store/_unity_catalog/registry/rest_store.py
+++ b/mlflow/store/_unity_catalog/registry/rest_store.py
@@ -176,31 +176,54 @@ def get_model_version_dependencies(model_dir):
         _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
         _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY,
     )
+    from mlflow.models.resources import ResourceType

     model = _load_model(model_dir)
     model_info = model.get_model_info()
     dependencies = []
-    index_names = _fetch_langchain_dependency_from_model_info(
-        model_info, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
-    )
-    for index_name in index_names:
-        dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", "name": index_name})
-    for key in (
-        _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
-        _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
-        _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
-    ):
-        endpoint_names = _fetch_langchain_dependency_from_model_info(model_info, key)
+
+    databricks_resources = getattr(model, "resources", {})
+
+    if databricks_resources:
+        databricks_dependencies = databricks_resources.get("databricks", {})
+        index_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, ResourceType.VECTOR_SEARCH_INDEX.value
+        )
+        for index_name in index_names:
+            dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", **index_name})
+        endpoint_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, ResourceType.SERVING_ENDPOINT.value
+        )
         for endpoint_name in endpoint_names:
-            dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name})
-    return dependencies
+            dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", **endpoint_name})
+    else:
+        # import here to work around circular imports
+        from mlflow.langchain.databricks_dependencies import _DATABRICKS_DEPENDENCY_KEY
+
+        databricks_dependencies = model_info.flavors.get("langchain", {}).get(
+            _DATABRICKS_DEPENDENCY_KEY, {}
+        )
+
+        index_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
+        )
+        for index_name in index_names:
+            dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", "name": index_name})
+        for key in (
+            _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
+            _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
+            _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
+        ):
+            endpoint_names = _fetch_langchain_dependency_from_model_info(
+                databricks_dependencies, key
+            )
+            for endpoint_name in endpoint_names:
+                dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name})
+    return dependencies


-def _fetch_langchain_dependency_from_model_info(model_info, key):
-    # import here to work around circular imports
-    from mlflow.langchain.databricks_dependencies import _DATABRICKS_DEPENDENCY_KEY
-
-    return model_info.flavors.get("langchain", {}).get(_DATABRICKS_DEPENDENCY_KEY, {}).get(key, [])
+def _fetch_langchain_dependency_from_model_info(databricks_dependencies, key):
+    return databricks_dependencies.get(key, [])


 @experimental
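The reader above now has two code paths: prefer the top-level `resources` block when present, otherwise fall back to the legacy `databricks_dependency` flavor config. Reduced to a sketch over a plain dict (the shapes mirror this patch's fixtures; this is not the real implementation):

def uc_dependencies(mlmodel: dict) -> list:
    # Illustrative stand-in for get_model_version_dependencies.
    dependencies = []
    databricks = mlmodel.get("resources", {}).get("databricks", {})
    for index in databricks.get("vector_search_index", []):
        dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", **index})
    for endpoint in databricks.get("serving_endpoint", []):
        dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", **endpoint})
    # If `databricks` is empty, the real code instead reads the per-key lists
    # under flavors["langchain"]["databricks_dependency"].
    return dependencies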
diff --git a/tests/langchain/test_langchain_databricks_dependency_extraction.py b/tests/langchain/test_langchain_databricks_dependency_extraction.py
index 7c2f5e1d0b41f..170397255e4de 100644
--- a/tests/langchain/test_langchain_databricks_dependency_extraction.py
+++ b/tests/langchain/test_langchain_databricks_dependency_extraction.py
@@ -17,6 +17,7 @@
     _extract_databricks_dependencies_from_retriever,
 )
 from mlflow.langchain.utils import IS_PICKLE_SERIALIZATION_RESTRICTED
+from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex


 class MockDatabricksServingEndpointClient:
@@ -54,8 +55,12 @@ def test_parsing_dependency_from_databricks_llm(monkeypatch: pytest.MonkeyPatch)
     llm = Databricks(**llm_kwargs)

     d = defaultdict(list)
-    _extract_databricks_dependencies_from_llm(llm, d)
+    resources = []
+    _extract_databricks_dependencies_from_llm(llm, d, resources)
     assert d.get(_DATABRICKS_LLM_ENDPOINT_NAME_KEY) == ["databricks-mixtral-8x7b-instruct"]
+    assert resources == [
+        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
+    ]


 class MockVectorSearchIndex:
@@ -96,10 +101,16 @@ def test_parsing_dependency_from_databricks_retriever(monkeypatch: pytest.Monkey
     vectorstore = DatabricksVectorSearch(vs_index, text_column="content", embedding=embedding_model)
     retriever = vectorstore.as_retriever()
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_retriever(retriever, d)
+    resources = []
+    _extract_databricks_dependencies_from_retriever(retriever, d, resources)
     assert d.get(_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY) == ["databricks-bge-large-en"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY) == ["mlflow.rag.vs_index"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY) == ["dbdemos_vs_endpoint"]
+    assert resources == [
+        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
+        DatabricksServingEndpoint(endpoint_name="dbdemos_vs_endpoint"),
+        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
+    ]


 @pytest.mark.skipif(
@@ -124,10 +135,16 @@ def test_parsing_dependency_from_databricks_retriever(monkeypatch: pytest.Monkey
     vectorstore = DatabricksVectorSearch(vs_index, text_column="content", embedding=embedding_model)
     retriever = vectorstore.as_retriever()
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_retriever(retriever, d)
+    resources = []
+    _extract_databricks_dependencies_from_retriever(retriever, d, resources)
     assert d.get(_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY) == ["databricks-bge-large-en"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY) == ["mlflow.rag.vs_index"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY) == ["dbdemos_vs_endpoint"]
+    assert resources == [
+        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
+        DatabricksServingEndpoint(endpoint_name="dbdemos_vs_endpoint"),
+        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
+    ]


 @pytest.mark.skipif(
@@ -142,8 +159,10 @@ def test_parsing_dependency_from_databricks_chat(monkeypatch: pytest.MonkeyPatch
     chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

     d = defaultdict(list)
-    _extract_databricks_dependencies_from_chat_model(chat_model, d)
+    resources = []
+    _extract_databricks_dependencies_from_chat_model(chat_model, d, resources)
     assert d.get(_DATABRICKS_CHAT_ENDPOINT_NAME_KEY) == ["databricks-llama-2-70b-chat"]
+    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]


 @pytest.mark.skipif(
@@ -158,5 +177,7 @@ def test_parsing_dependency_from_databricks_chat(monkeypatch: pytest.MonkeyPatch
     chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

     d = defaultdict(list)
-    _extract_databricks_dependencies_from_chat_model(chat_model, d)
+    resources = []
+    _extract_databricks_dependencies_from_chat_model(chat_model, d, resources)
     assert d.get(_DATABRICKS_CHAT_ENDPOINT_NAME_KEY) == ["databricks-llama-2-70b-chat"]
+    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]
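These tests all follow the same accumulator pattern: the caller owns both containers and each extractor appends only what it recognizes, which is also how `_extract_dependency_dict_from_lc_model` drives them internally. Schematically (`component` is any single chain node, named here for illustration):

from collections import defaultdict

from mlflow.langchain.databricks_dependencies import (
    _extract_databricks_dependencies_from_chat_model,
    _extract_databricks_dependencies_from_llm,
    _extract_databricks_dependencies_from_retriever,
)

def collect(component):
    # Both accumulators are mutated in place by each helper.
    d = defaultdict(list)
    resources = []
    for extract in (
        _extract_databricks_dependencies_from_chat_model,
        _extract_databricks_dependencies_from_retriever,
        _extract_databricks_dependencies_from_llm,
    ):
        extract(component, d, resources)
    return d, resources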
diff --git a/tests/langchain/test_langchain_model_export.py b/tests/langchain/test_langchain_model_export.py
index 32cade7d3fa16..6affe36d267f2 100644
--- a/tests/langchain/test_langchain_model_export.py
+++ b/tests/langchain/test_langchain_model_export.py
@@ -55,7 +55,12 @@
     _LC_MIN_VERSION_SUPPORT_CHAT_OPEN_AI,
     IS_PICKLE_SERIALIZATION_RESTRICTED,
 )
+from mlflow.models import Model
+from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex, Resource
 from mlflow.models.signature import ModelSignature, Schema, infer_signature
+from mlflow.tracking.artifact_utils import (
+    _download_artifact_from_uri,
+)
 from mlflow.types.schema import Array, ColSpec, DataType, Object, Property
 from mlflow.utils.openai_utils import (
     TEST_CONTENT,
@@ -1776,9 +1781,12 @@ def extract_history(input):
     }


-def _extract_endpoint_name_from_lc_model(lc_model, dependency_dict: DefaultDict[str, List[Any]]):
+def _extract_endpoint_name_from_lc_model(
+    lc_model, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
+):
     if type(lc_model).__name__ == type(get_fake_chat_model()).__name__:
         dependency_dict["fake_chat_model_endpoint_name"].append(lc_model.endpoint_name)
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=lc_model.endpoint_name))


 @pytest.mark.skipif(
@@ -1803,8 +1811,9 @@ def test_databricks_dependency_extraction_from_lcel_chain():

     chain = prompt_1 | {"joke1": model_1, "joke2": model_2} | prompt_2 | model_3 | output_parser

-    with mlflow.start_run():
-        model_info = mlflow.langchain.log_model(chain, "basic_chain")
+    pyfunc_artifact_path = "basic_chain"
+    with mlflow.start_run() as run:
+        model_info = mlflow.langchain.log_model(chain, pyfunc_artifact_path)

     langchain_flavor = model_info.flavors["langchain"]
     assert langchain_flavor["databricks_dependency"] == {
@@ -1814,10 +1823,20 @@
         "fake_chat_model_endpoint_name": [
             "fake-endpoint-1",
             "fake-endpoint-2",
             "fake-endpoint-3",
         ]
     }
+    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
+    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
+    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
+    assert reloaded_model.resources["databricks"] == {
+        "serving_endpoint": [
+            {"name": "fake-endpoint-1"},
+            {"name": "fake-endpoint-2"},
+            {"name": "fake-endpoint-3"},
+        ]
+    }
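The assertions above re-read the MLmodel file from the run's artifacts rather than trusting the in-memory ModelInfo, and the same round trip recurs in the tests below. Extracted as a sketch (assumes a `run` handle from mlflow.start_run(), as in the tests):

import os

from mlflow.models import Model
from mlflow.tracking.artifact_utils import _download_artifact_from_uri

model_uri = f"runs:/{run.info.run_id}/basic_chain"  # artifact path per test
local_path = _download_artifact_from_uri(model_uri)
reloaded = Model.load(os.path.join(local_path, "MLmodel"))
# reloaded.resources is None when no Databricks dependency was detected.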
assert langchain_flavor["databricks_dependency"] == { @@ -1814,10 +1823,20 @@ def test_databricks_dependency_extraction_from_lcel_chain(): "fake-endpoint-3", ] } + pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}" + pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) + reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) + assert reloaded_model.resources["databricks"] == { + "serving_endpoint": [ + {"name": "fake-endpoint-1"}, + {"name": "fake-endpoint-2"}, + {"name": "fake-endpoint-3"}, + ] + } def _extract_databricks_dependencies_from_retriever( - retriever, dependency_dict: DefaultDict[str, List[Any]] + retriever, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource] ): import langchain_community @@ -1825,15 +1844,20 @@ def _extract_databricks_dependencies_from_retriever( if vectorstore: if isinstance(vectorstore, langchain_community.vectorstores.faiss.FAISS): dependency_dict["fake_index"].append("faiss-index") + dependency_list.append(DatabricksVectorSearchIndex(index_name="faiss-index")) embeddings = getattr(vectorstore, "embeddings", None) if isinstance(embeddings, FakeEmbeddings): dependency_dict["fake_embeddings_size"].append(embeddings.size) + dependency_list.append(DatabricksServingEndpoint(endpoint_name="fake-embeddings")) -def _extract_databricks_dependencies_from_llm(llm, dependency_dict: DefaultDict[str, List[Any]]): +def _extract_databricks_dependencies_from_llm( + llm, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource] +): if isinstance(llm, FakeLLM): dependency_dict["fake_llm_endpoint_name"].append(llm.endpoint_name) + dependency_list.append(DatabricksServingEndpoint(endpoint_name=llm.endpoint_name)) @pytest.mark.skipif( @@ -1867,10 +1891,11 @@ def load_retriever(persist_directory): vectorstore = FAISS.load_local(persist_directory, embeddings) return vectorstore.as_retriever() - with mlflow.start_run(): + pyfunc_artifact_path = "retrieval_qa_chain" + with mlflow.start_run() as run: logged_model = mlflow.langchain.log_model( retrievalQA, - "retrieval_qa_chain", + pyfunc_artifact_path, loader_fn=load_retriever, persist_dir=persist_dir, ) @@ -1880,6 +1905,20 @@ def load_retriever(persist_directory): "fake_index": ["faiss-index"], "fake_embeddings_size": [5], } + pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}" + pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) + reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) + actual = reloaded_model.resources["databricks"] + expected = { + "serving_endpoint": [ + {"name": "fake-llm-endpoint"}, + {"name": "fake-embeddings"}, + ], + "vector_search_index": [{"name": "faiss-index"}], + } + assert all(item in actual["serving_endpoint"] for item in expected["serving_endpoint"]) + assert all(item in expected["serving_endpoint"] for item in actual["serving_endpoint"]) + assert actual["vector_search_index"] == expected["vector_search_index"] def _error_func(*args, **kwargs): @@ -1908,9 +1947,14 @@ def test_databricks_dependency_extraction_log_errors_as_warnings(mock_warning): with pytest.raises(ValueError, match="error"): _detect_databricks_dependencies(model, log_errors_as_warnings=False) - with mlflow.start_run(): - logged_model = mlflow.langchain.log_model(model, "langchain_model") + pyfunc_artifact_path = "langchain_model" + with mlflow.start_run() as run: + logged_model = mlflow.langchain.log_model(model, pyfunc_artifact_path) assert 
logged_model.flavors["langchain"].get(_DATABRICKS_DEPENDENCY_KEY) is None + pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}" + pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) + reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) + assert reloaded_model.resources is None @pytest.mark.skipif( @@ -2337,10 +2381,11 @@ def test_save_load_chain_as_code(chain_model_signature): } ] } - with mlflow.start_run(): + artifact_path = "model_path" + with mlflow.start_run() as run: model_info = mlflow.langchain.log_model( lc_model="tests/langchain/chain.py", - artifact_path="model_path", + artifact_path=artifact_path, signature=chain_model_signature, input_example=input_example, model_config="tests/langchain/config.yml", @@ -2374,6 +2419,12 @@ def test_save_load_chain_as_code(chain_model_signature): assert langchain_flavor["databricks_dependency"] == { "databricks_chat_endpoint_name": ["fake-endpoint"] } + pyfunc_model_uri = f"runs:/{run.info.run_id}/{artifact_path}" + pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) + reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) + assert reloaded_model.resources["databricks"] == { + "serving_endpoint": [{"name": "fake-endpoint"}] + } @pytest.mark.skipif( @@ -2546,10 +2597,11 @@ def test_save_load_chain_as_code_optional_code_path(chain_model_signature): } ] } - with mlflow.start_run(): + artifact_path = "model_path" + with mlflow.start_run() as run: model_info = mlflow.langchain.log_model( lc_model="tests/langchain/no_config/chain.py", - artifact_path="model_path", + artifact_path=artifact_path, signature=chain_model_signature, input_example=input_example, ) @@ -2585,6 +2637,12 @@ def test_save_load_chain_as_code_optional_code_path(chain_model_signature): assert langchain_flavor["databricks_dependency"] == { "databricks_chat_endpoint_name": ["fake-endpoint"] } + pyfunc_model_uri = f"runs:/{run.info.run_id}/{artifact_path}" + pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) + reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) + assert reloaded_model.resources["databricks"] == { + "serving_endpoint": [{"name": "fake-endpoint"}] + } def test_config_path_context(): diff --git a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py index cebf04ec345e4..9a60842abc159 100644 --- a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py +++ b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py @@ -222,6 +222,31 @@ def langchain_local_model_dir(tmp_path): return tmp_path +@pytest.fixture +def langchain_local_model_dir_with_resources(tmp_path): + fake_signature = ModelSignature( + inputs=Schema([ColSpec(DataType.string)]), outputs=Schema([ColSpec(DataType.string)]) + ) + fake_mlmodel_contents = { + "artifact_path": "some-artifact-path", + "run_id": "abc123", + "signature": fake_signature.to_dict(), + "resources": { + "databricks": { + "serving_endpoint": [ + {"name": "embedding_endpoint"}, + {"name": "llm_endpoint"}, + {"name": "chat_endpoint"}, + ], + "vector_search_index": [{"name": "index1"}, {"name": "index2"}], + } + }, + } + with open(tmp_path.joinpath(MLMODEL_FILE_NAME), "w") as handle: + yaml.dump(fake_mlmodel_contents, handle) + return tmp_path + + @pytest.fixture def langchain_local_model_dir_no_dependencies(tmp_path): fake_signature = ModelSignature( @@ -300,6 +325,68 @@ def 
diff --git a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py
index cebf04ec345e4..9a60842abc159 100644
--- a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py
+++ b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py
@@ -222,6 +222,31 @@ def langchain_local_model_dir(tmp_path):
     return tmp_path


+@pytest.fixture
+def langchain_local_model_dir_with_resources(tmp_path):
+    fake_signature = ModelSignature(
+        inputs=Schema([ColSpec(DataType.string)]), outputs=Schema([ColSpec(DataType.string)])
+    )
+    fake_mlmodel_contents = {
+        "artifact_path": "some-artifact-path",
+        "run_id": "abc123",
+        "signature": fake_signature.to_dict(),
+        "resources": {
+            "databricks": {
+                "serving_endpoint": [
+                    {"name": "embedding_endpoint"},
+                    {"name": "llm_endpoint"},
+                    {"name": "chat_endpoint"},
+                ],
+                "vector_search_index": [{"name": "index1"}, {"name": "index2"}],
+            }
+        },
+    }
+    with open(tmp_path.joinpath(MLMODEL_FILE_NAME), "w") as handle:
+        yaml.dump(fake_mlmodel_contents, handle)
+    return tmp_path
+
+
 @pytest.fixture
 def langchain_local_model_dir_no_dependencies(tmp_path):
     fake_signature = ModelSignature(
@@ -300,6 +325,68 @@ def test_create_model_version_with_langchain_dependencies(store, langchain_local
     )


+def test_create_model_version_with_resources(store, langchain_local_model_dir_with_resources):
+    access_key_id = "fake-key"
+    secret_access_key = "secret-key"
+    session_token = "session-token"
+    aws_temp_creds = TemporaryCredentials(
+        aws_temp_credentials=AwsCredentials(
+            access_key_id=access_key_id,
+            secret_access_key=secret_access_key,
+            session_token=session_token,
+        )
+    )
+    storage_location = "s3://blah"
+    source = str(langchain_local_model_dir_with_resources)
+    model_name = "model_1"
+    version = "1"
+    tags = [
+        ModelVersionTag(key="key", value="value"),
+        ModelVersionTag(key="anotherKey", value="some other value"),
+    ]
+    model_version_dependencies = [
+        {"type": "DATABRICKS_VECTOR_INDEX", "name": "index1"},
+        {"type": "DATABRICKS_VECTOR_INDEX", "name": "index2"},
+        {"type": "DATABRICKS_MODEL_ENDPOINT", "name": "embedding_endpoint"},
+        {"type": "DATABRICKS_MODEL_ENDPOINT", "name": "llm_endpoint"},
+        {"type": "DATABRICKS_MODEL_ENDPOINT", "name": "chat_endpoint"},
+    ]
+
+    mock_artifact_repo = mock.MagicMock(autospec=OptimizedS3ArtifactRepository)
+    with mock.patch(
+        "mlflow.utils.rest_utils.http_request",
+        side_effect=get_request_mock(
+            name=model_name,
+            version=version,
+            temp_credentials=aws_temp_creds,
+            storage_location=storage_location,
+            source=source,
+            tags=tags,
+            model_version_dependencies=model_version_dependencies,
+        ),
+    ) as request_mock, mock.patch(
+        "mlflow.store.artifact.optimized_s3_artifact_repo.OptimizedS3ArtifactRepository",
+        return_value=mock_artifact_repo,
+    ) as optimized_s3_artifact_repo_class_mock, mock.patch.dict("sys.modules", {"boto3": {}}):
+        store.create_model_version(name=model_name, source=source, tags=tags)
+        # Verify that s3 artifact repo mock was called with expected args
+        optimized_s3_artifact_repo_class_mock.assert_called_once_with(
+            artifact_uri=storage_location,
+            access_key_id=access_key_id,
+            secret_access_key=secret_access_key,
+            session_token=session_token,
+        )
+        mock_artifact_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
+        _assert_create_model_version_endpoints_called(
+            request_mock=request_mock,
+            name=model_name,
+            source=source,
+            version=version,
+            tags=tags,
+            model_version_dependencies=model_version_dependencies,
+        )
+
+
 def test_create_model_version_with_langchain_no_dependencies(
     store, langchain_local_model_dir_no_dependencies
 ):
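End to end, the behavior these tests pin down can be sketched as follows; the chain object, catalog path, and registry setup are illustrative, not part of this patch:

import mlflow

mlflow.set_registry_uri("databricks-uc")  # Unity Catalog registry (assumed context)

with mlflow.start_run():
    # Logging a chain that talks to Databricks endpoints now also writes a
    # top-level `resources` block into the MLmodel file.
    info = mlflow.langchain.log_model(chain, "chain")  # `chain` defined elsewhere

# Registration derives DATABRICKS_MODEL_ENDPOINT / DATABRICKS_VECTOR_INDEX
# dependencies from that block, falling back to the legacy flavor config for
# models logged before this change.
mlflow.register_model(info.model_uri, "main.default.my_chain")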