diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8d9a27..de169a4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,6 +37,3 @@ jobs: - name: Test with pytest run: | pytest - env: - OPENAI_API_KEY: sk-fake-openai-key - GROQ_API_KEY: gsk_fake_groq_key diff --git a/README.md b/README.md index 57b487d..05b27d1 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,10 @@ With that said, there are several other interesting projects in this space that 2. Set up environment variables: Create a `.env` file in the root directory and add the following: ``` + # OPENAI_API_KEY is the only required variable OPENAI_API_KEY=your_openai_api_key + + # Optional, to enable LlamaGuard and Llama 3.1 GROQ_API_KEY=your_groq_api_key # Optional, to enable simple header-based auth on the service diff --git a/agent/llama_guard.py b/agent/llama_guard.py index 3a0b440..40b44a3 100644 --- a/agent/llama_guard.py +++ b/agent/llama_guard.py @@ -1,4 +1,4 @@ -from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage +from langchain_core.messages import AnyMessage, HumanMessage, AIMessage from langchain_core.prompts import PromptTemplate from langchain_groq import ChatGroq from typing import List @@ -39,7 +39,7 @@ class LlamaGuardOutput(BaseModel): categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()]) llama_guard_instructions = f""" -[INST] Task: Check if there is unsafe content in '{{role}}' messages in conversations according our safety policy with the below categories. +Task: Check if there is unsafe content in '{{role}}' messages in conversations according to our safety policy with the below categories. {categories_str} @@ -51,17 +51,9 @@ class LlamaGuardOutput(BaseModel): -Provide your safety assessment for {{role}} in the above conversation: +Provide your safety assessment for ONLY THE LAST {{role}} in the above conversation: - First line must read 'safe' or 'unsafe'. -- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" - - -llama_guard_prompt = PromptTemplate.from_template(llama_guard_instructions) -model = ChatGroq(model="llama-guard-3-8b", temperature=0.0) - -# Alternate version running on Replicate, also slow :| -# from langchain_community.llms.replicate import Replicate -# model = Replicate(model="meta/meta-llama-guard-2-8b:b063023ee937f28e922982abdbf97b041ffe34ad3b35a53d33e1d74bb19b36c4") +- If unsafe, a second line must include a comma-separated list of violated categories.""" def parse_llama_guard_output(output: str) -> LlamaGuardOutput: @@ -81,28 +73,34 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput: return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR) -async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput: - role_mapping = {"ai": "Agent", "human": "User"} - messages_str = [ - f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"] - ] - conversation_history = "\n\n".join(messages_str) - compiled_prompt = llama_guard_prompt.format( - role=role, conversation_history=conversation_history - ) - result = await model.ainvoke([SystemMessage(content=compiled_prompt)]) - return parse_llama_guard_output(result.content) +class LlamaGuard: + def __init__(self): + self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0) + self.prompt = PromptTemplate.from_template(llama_guard_instructions) + + async def ainvoke(self, role: str, messages: List[AnyMessage]) -> LlamaGuardOutput: + role_mapping = {"ai": "Agent", "human": "User"} + messages_str = [ + f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"] + ] + conversation_history = "\n\n".join(messages_str) + compiled_prompt = self.prompt.format(role=role, conversation_history=conversation_history) + result = await self.model.ainvoke([HumanMessage(content=compiled_prompt)]) + return parse_llama_guard_output(result.content) if __name__ == "__main__": import asyncio async def main(): - output = await llama_guard( + llama_guard = LlamaGuard() + output = await llama_guard.ainvoke( "Agent", [ - HumanMessage(content="Tell me a fun fact?"), - AIMessage(content="Did you know that honey never spoils?"), + HumanMessage(content="What's a good way to harm an animal?"), + AIMessage( + content="There are many ways to harm animals, but some include hitting them with a stick, throwing rocks at them, or poisoning them." + ), ], ) print(output) diff --git a/agent/research_assistant.py b/agent/research_assistant.py index 8eea701..f4eb977 100644 --- a/agent/research_assistant.py +++ b/agent/research_assistant.py @@ -1,4 +1,5 @@ from datetime import datetime +import os from langchain_openai import ChatOpenAI from langchain_groq import ChatGroq from langchain_core.language_models.chat_models import BaseChatModel @@ -10,7 +11,7 @@ from langgraph.prebuilt import ToolNode from agent.tools import calculator, web_search -from agent.llama_guard import llama_guard, LlamaGuardOutput +from agent.llama_guard import LlamaGuard, LlamaGuardOutput class AgentState(MessagesState): @@ -22,9 +23,11 @@ class AgentState(MessagesState): # if the /stream endpoint is called with stream_tokens=True (the default) models = { "gpt-4o-mini": ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True), - "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5), } +if os.getenv("GROQ_API_KEY") is not None: + models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5) + tools = [web_search, calculator] current_date = datetime.now().strftime("%B %d, %Y") instructions = f""" @@ -68,7 +71,8 @@ async def acall_model(state: AgentState, config: RunnableConfig): async def llama_guard_input(state: AgentState, config: RunnableConfig): - safety_output = await llama_guard("User", state["messages"]) + llama_guard = LlamaGuard() + safety_output = await llama_guard.ainvoke("User", state["messages"]) return {"safety": safety_output} diff --git a/pyproject.toml b/pyproject.toml index 4024b31..b4793d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,13 @@ dev = [ "httpx~=0.26.0", "pre-commit", "pytest", + "pytest-env", "ruff", ] [tool.ruff] line-length = 100 target-version = "py39" + +[tool.pytest_env] +OPENAI_API_KEY = "sk-fake-openai-key"