Skip to content

Commit

Permalink
Merge pull request #398 from webcoderz/asyncpg
Browse files Browse the repository at this point in the history
Async pg module w/ connection pooling
  • Loading branch information
jlowin authored Jan 11, 2025
2 parents 59d6c90 + e54eaf9 commit bcf422e
Show file tree
Hide file tree
Showing 14 changed files with 800 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ cython_debug/
src/controlflow/_version.py
all_code.md
all_docs.md
llm_guides.md
llm_guides.md
55 changes: 55 additions & 0 deletions examples/asyncpg-memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio

import controlflow as cf
from controlflow.memory.async_memory import AsyncMemory
from controlflow.memory.providers.postgres import AsyncPostgresMemory

provider = AsyncPostgresMemory(
database_url="postgresql+psycopg://postgres:postgres@localhost:5432/database",
# embedding_dimension=1536,
# embedding_fn=OpenAIEmbeddings(),
table_name="vector_db_async",
)

# Create a memory module for user preferences
user_preferences = AsyncMemory(
key="user_preferences",
instructions="Store and retrieve user preferences.",
provider=provider,
)

# Create an agent with access to the memory
agent = cf.Agent(memories=[user_preferences])


# Create a flow to ask for the user's favorite color
@cf.flow
async def remember_pet():
return await cf.run_async(
"Ask the user for their favorite animal and store it in memory",
agents=[agent],
interactive=True,
)


# Create a flow to recall the user's favorite color
@cf.flow
async def recall_pet():
return await cf.run_async(
"What is the user's favorite animal?",
agents=[agent],
)


async def main():
print("First flow:")
await remember_pet()

print("\nSecond flow:")
result = await recall_pet()
print(result)
return result


if __name__ == "__main__":
asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/pg-memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from controlflow.memory.providers.postgres import PostgresMemory

provider = PostgresMemory(
database_url="postgresql://postgres:postgres@localhost:5432/your_database",
database_url="postgresql://postgres:postgres@localhost:5432/database",
# embedding_dimension=1536,
# embedding_fn=OpenAIEmbeddings(),
table_name="vector_db",
Expand Down
1 change: 1 addition & 0 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# functions, utilites, and decorators
from .memory import Memory
from .memory.async_memory import AsyncMemory
from .instructions import instructions
from .decorators import flow, task
from .tools import tool
Expand Down
3 changes: 2 additions & 1 deletion src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from controlflow.llm.models import get_model as get_model_from_string
from controlflow.llm.rules import LLMRules
from controlflow.memory import Memory
from controlflow.memory.async_memory import AsyncMemory
from controlflow.tools.tools import (
Tool,
as_lc_tools,
Expand Down Expand Up @@ -82,7 +83,7 @@ class Agent(ControlFlowModel, abc.ABC):
default=False,
description="If True, the agent is given tools for interacting with a human user.",
)
memories: list[Memory] = Field(
memories: list[Memory] | list[AsyncMemory] = Field(
default=[],
description="A list of memory modules for the agent to use.",
)
Expand Down
5 changes: 4 additions & 1 deletion src/controlflow/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import controlflow.utilities
import controlflow.utilities.logging
from controlflow.llm.models import BaseChatModel
from controlflow.memory.async_memory import AsyncMemoryProvider, get_memory_provider
from controlflow.memory.memory import MemoryProvider, get_memory_provider
from controlflow.utilities.general import ControlFlowModel

Expand Down Expand Up @@ -39,7 +40,9 @@ class Defaults(ControlFlowModel):
model: Optional[Any]
history: History
agent: Agent
memory_provider: Optional[Union[MemoryProvider, str]]
memory_provider: (
Optional[Union[MemoryProvider, str]] | Optional[Union[AsyncMemoryProvider, str]]
)

# add more defaults here
def __repr__(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions src/controlflow/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .memory import Memory
from .async_memory import AsyncMemory
149 changes: 149 additions & 0 deletions src/controlflow/memory/async_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import abc
import re
from typing import Dict, List, Optional, Union

from pydantic import Field, field_validator, model_validator

import controlflow
from controlflow.tools.tools import Tool
from controlflow.utilities.general import ControlFlowModel, unwrap
from controlflow.utilities.logging import get_logger

logger = get_logger("controlflow.memory")


def sanitize_memory_key(key: str) -> str:
# Remove any characters that are not alphanumeric or underscore
return re.sub(r"[^a-zA-Z0-9_]", "", key)


class AsyncMemoryProvider(ControlFlowModel, abc.ABC):
async def configure(self, memory_key: str) -> None:
"""Configure the provider for a specific memory."""
pass

@abc.abstractmethod
async def add(self, memory_key: str, content: str) -> str:
"""Create a new memory and return its ID."""
pass

@abc.abstractmethod
async def delete(self, memory_key: str, memory_id: str) -> None:
"""Delete a memory by its ID."""
pass

@abc.abstractmethod
async def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
"""Search for n memories using a string query."""
pass


class AsyncMemory(ControlFlowModel):
"""
A memory module is a partitioned collection of memories that are stored in a
vector database, configured by a MemoryProvider.
"""

key: str
instructions: str = Field(
description="Explain what this memory is for and how it should be used."
)
provider: AsyncMemoryProvider = Field(
default_factory=lambda: controlflow.defaults.memory_provider,
validate_default=True,
)

def __hash__(self) -> int:
return id(self)

@field_validator("provider", mode="before")
@classmethod
def validate_provider(
cls, v: Optional[Union[AsyncMemoryProvider, str]]
) -> AsyncMemoryProvider:
if isinstance(v, str):
return get_memory_provider(v)
if v is None:
raise ValueError(
unwrap(
"""
Memory modules require a MemoryProvider to configure the
underlying vector database. No provider was passed as an
argument, and no default value has been configured.
For more information on configuring a memory provider, see
the [Memory
documentation](https://controlflow.ai/patterns/memory), and
please review the [default provider
guide](https://controlflow.ai/guides/default-memory) for
information on configuring a default provider.
Please note that if you are using ControlFlow for the first
time, this error is expected because ControlFlow does not include
vector dependencies by default.
"""
)
)
return v

@field_validator("key")
@classmethod
def validate_key(cls, v: str) -> str:
sanitized = sanitize_memory_key(v)
if sanitized != v:
raise ValueError(
"Memory key must contain only alphanumeric characters and underscores"
)
return sanitized

async def _configure_provider(self):
await self.provider.configure(self.key)
return self

async def add(self, content: str) -> str:
return await self.provider.add(self.key, content)

async def delete(self, memory_id: str) -> None:
await self.provider.delete(self.key, memory_id)

async def search(self, query: str, n: int = 20) -> Dict[str, str]:
return await self.provider.search(self.key, query, n)

def get_tools(self) -> List[Tool]:
return [
Tool.from_function(
self.add,
name=f"store_memory_{self.key}",
description=f'Create a new memory in Memory: "{self.key}".',
),
Tool.from_function(
self.delete,
name=f"delete_memory_{self.key}",
description=f'Delete a memory by its ID from Memory: "{self.key}".',
),
Tool.from_function(
self.search,
name=f"search_memories_{self.key}",
description=f'Search for memories relevant to a string query in Memory: "{self.key}". Returns a dictionary of memory IDs and their contents.',
),
]


def get_memory_provider(provider: str) -> AsyncMemoryProvider:
logger.debug(f"Loading memory provider: {provider}")

# --- async postgres ---

if provider.startswith("async-postgres"):
try:
import sqlalchemy
except ImportError:
raise ImportError(
"""To use async Postgres as a memory provider, please install the `sqlalchemy, `psycopg-pool`,
`psycopg-binary`, and `psycopg` packages."""
)

import controlflow.memory.providers.postgres as postgres_providers

return postgres_providers.AsyncPostgresMemory()
raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
4 changes: 3 additions & 1 deletion src/controlflow/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@ def get_memory_provider(provider: str) -> MemoryProvider:
import sqlalchemy
except ImportError:
raise ImportError(
"To use Postgres as a memory provider, please install the `sqlalchemy` package."
"""To use Postgres as a memory provider, please install the `sqlalchemy, `psycopg-pool`,
`psycopg-binary`, and `psycopg` `psycopg2-binary` packages."""
)

import controlflow.memory.providers.postgres as postgres_providers

return postgres_providers.PostgresMemory()

raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
Loading

0 comments on commit bcf422e

Please sign in to comment.