From 8c246733eea0a499db6f6232231fbed3b4d41a75 Mon Sep 17 00:00:00 2001
From: Nada Amin
Date: Fri, 29 Nov 2024 17:46:09 +0100
Subject: [PATCH 1/2] Support provider Ollama for automatic configuration.

---
 docs/guides/configure-llms.mdx | 1 +
 src/controlflow/llm/models.py  | 8 ++++++++
 tests/llm/test_models.py       | 5 +++++
 3 files changed, 14 insertions(+)

diff --git a/docs/guides/configure-llms.mdx b/docs/guides/configure-llms.mdx
index 8df6e55d..438db196 100644
--- a/docs/guides/configure-llms.mdx
+++ b/docs/guides/configure-llms.mdx
@@ -47,6 +47,7 @@ At this time, supported providers for automatic configuration include:
 | Anthropic | `anthropic` | (included) |
 | Google | `google` | `langchain_google_genai` |
 | Groq | `groq` | `langchain_groq` |
+| Ollama | `ollama` | `langchain-ollama` |
 
 If the required dependencies are not installed, ControlFlow will be unable to load the model and will raise an error.
 
diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py
index 6938e197..11a66c8e 100644
--- a/src/controlflow/llm/models.py
+++ b/src/controlflow/llm/models.py
@@ -60,6 +60,14 @@ def get_model(
                 "To use Groq as an LLM provider, please install the `langchain_groq` package."
             )
         cls = ChatGroq
+    elif provider == "ollama":
+        try:
+            from langchain_ollama import ChatOllama
+        except ImportError:
+            raise ImportError(
+                "To use Ollama as an LLM provider, please install the `langchain-ollama` package."
+            )
+        cls = ChatOllama
     else:
         raise ValueError(
             f"Could not load provider `{provider}` automatically. Please provide the LLM class manually."
diff --git a/tests/llm/test_models.py b/tests/llm/test_models.py
index b46bb0d0..fe2e3c27 100644
--- a/tests/llm/test_models.py
+++ b/tests/llm/test_models.py
@@ -2,6 +2,7 @@
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 from controlflow.llm.models import get_model
@@ -45,6 +46,10 @@ def test_get_groq_model(monkeypatch):
     assert isinstance(model, ChatGroq)
     assert model.model_name == "mixtral-8x7b-32768"
 
+def test_get_ollama_model(monkeypatch):
+    model = get_model("ollama/qwen2.5")
+    assert isinstance(model, ChatOllama)
+    assert model.model == "qwen2.5"
 
 def test_get_model_with_invalid_format():
     with pytest.raises(ValueError, match="The model `gpt-4o` is not valid."):

From 6cab490e42b386589abc3f450fb7aae476be91b4 Mon Sep 17 00:00:00 2001
From: Nada Amin
Date: Fri, 29 Nov 2024 17:53:16 +0100
Subject: [PATCH 2/2] add langchain-ollama as an optional tests dependency

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 9ec67d61..c569a797 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ tests = [
     "langchain_community",
     "langchain_google_genai",
     "langchain_groq",
+    "langchain-ollama",
    "pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
    "pytest-env>=0.8,<2.0",
    "pytest-rerunfailures>=10,<14",
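
For reference, a minimal usage sketch of the new provider path. It assumes `langchain-ollama` is installed and an Ollama server is running locally with the `qwen2.5` model pulled; the model name just mirrors the one in the test above, and any pulled model should work the same way:

    from controlflow.llm.models import get_model

    # The "ollama/" prefix routes get_model to ChatOllama; the rest of
    # the string is passed through as the Ollama model name (the new
    # test asserts model.model == "qwen2.5" for "ollama/qwen2.5").
    model = get_model("ollama/qwen2.5")

    # ChatOllama is a standard LangChain chat model, so it can be
    # invoked directly; this assumes langchain-ollama's default local
    # server address, http://localhost:11434.
    print(model.invoke("Say hello in one word.").content)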