Simplify ENV requirements #21

Merged · 3 commits · Sep 2, 2024
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
@@ -37,6 +37,3 @@ jobs:
       - name: Test with pytest
         run: |
           pytest
-        env:
-          OPENAI_API_KEY: sk-fake-openai-key
-          GROQ_API_KEY: gsk_fake_groq_key
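The hard-coded fake keys can come out of the CI workflow because the pytest-env plugin (added to the dev dependencies in pyproject.toml below) injects them when pytest starts. A minimal sketch of a test that relies on that, assuming a hypothetical tests/test_env.py:

```python
# tests/test_env.py -- hypothetical sketch; assumes the [tool.pytest_env]
# section added to pyproject.toml in this PR.
import os


def test_fake_openai_key_is_injected():
    # pytest-env sets this variable before tests run, so the workflow no
    # longer needs its own env block.
    assert os.environ["OPENAI_API_KEY"] == "sk-fake-openai-key"
```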
3 changes: 3 additions & 0 deletions README.md
@@ -84,7 +84,10 @@ With that said, there are several other interesting projects in this space that
 2. Set up environment variables:
    Create a `.env` file in the root directory and add the following:
    ```
+   # OPENAI_API_KEY is the only required variable
    OPENAI_API_KEY=your_openai_api_key
+
+   # Optional, to enable LlamaGuard and Llama 3.1
    GROQ_API_KEY=your_groq_api_key
 
    # Optional, to enable simple header-based auth on the service
50 changes: 24 additions & 26 deletions agent/llama_guard.py
@@ -1,4 +1,4 @@
-from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
+from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
 from langchain_core.prompts import PromptTemplate
 from langchain_groq import ChatGroq
 from typing import List
@@ -39,7 +39,7 @@ class LlamaGuardOutput(BaseModel):
 categories_str = "\n".join([f"{k}: {v}" for k, v in unsafe_content_categories.items()])
 
 llama_guard_instructions = f"""
-[INST] Task: Check if there is unsafe content in '{{role}}' messages in conversations according our safety policy with the below categories.
+Task: Check if there is unsafe content in '{{role}}' messages in conversations according to our safety policy with the below categories.
 
 <BEGIN UNSAFE CONTENT CATEGORIES>
 {categories_str}
@@ -51,17 +51,9 @@ class LlamaGuardOutput(BaseModel):

 <END CONVERSATION>
 
-Provide your safety assessment for {{role}} in the above conversation:
+Provide your safety assessment for ONLY THE LAST {{role}} in the above conversation:
 - First line must read 'safe' or 'unsafe'.
-- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
-
-
-llama_guard_prompt = PromptTemplate.from_template(llama_guard_instructions)
-model = ChatGroq(model="llama-guard-3-8b", temperature=0.0)
-
-# Alternate version running on Replicate, also slow :|
-# from langchain_community.llms.replicate import Replicate
-# model = Replicate(model="meta/meta-llama-guard-2-8b:b063023ee937f28e922982abdbf97b041ffe34ad3b35a53d33e1d74bb19b36c4")
+- If unsafe, a second line must include a comma-separated list of violated categories."""
 
 
 def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
@@ -81,28 +73,34 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
     return LlamaGuardOutput(safety_assessment=SafetyAssessment.ERROR)
 
 
-async def llama_guard(role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
-    role_mapping = {"ai": "Agent", "human": "User"}
-    messages_str = [
-        f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]
-    ]
-    conversation_history = "\n\n".join(messages_str)
-    compiled_prompt = llama_guard_prompt.format(
-        role=role, conversation_history=conversation_history
-    )
-    result = await model.ainvoke([SystemMessage(content=compiled_prompt)])
-    return parse_llama_guard_output(result.content)
+class LlamaGuard:
+    def __init__(self):
+        self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0)
+        self.prompt = PromptTemplate.from_template(llama_guard_instructions)
+
+    async def ainvoke(self, role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
+        role_mapping = {"ai": "Agent", "human": "User"}
+        messages_str = [
+            f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]
+        ]
+        conversation_history = "\n\n".join(messages_str)
+        compiled_prompt = self.prompt.format(role=role, conversation_history=conversation_history)
+        result = await self.model.ainvoke([HumanMessage(content=compiled_prompt)])
+        return parse_llama_guard_output(result.content)
 
 
 if __name__ == "__main__":
     import asyncio
 
     async def main():
-        output = await llama_guard(
+        llama_guard = LlamaGuard()
+        output = await llama_guard.ainvoke(
             "Agent",
             [
-                HumanMessage(content="Tell me a fun fact?"),
-                AIMessage(content="Did you know that honey never spoils?"),
+                HumanMessage(content="What's a good way to harm an animal?"),
+                AIMessage(
+                    content="There are many ways to harm animals, but some include hitting them with a stick, throwing rocks at them, or poisoning them."
+                ),
             ],
         )
         print(output)
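Moving the module-level model and prompt into LlamaGuard.__init__ defers the ChatGroq client construction from import time to instantiation time, so importing agent.llama_guard no longer requires GROQ_API_KEY to be set. A minimal sketch of a caller under that assumption; check_message is hypothetical and not part of this PR:

```python
# Hypothetical caller sketch. Importing the module is now safe without
# GROQ_API_KEY; only instantiating LlamaGuard requires it.
import asyncio

from langchain_core.messages import HumanMessage

from agent.llama_guard import LlamaGuard


async def check_message(text: str):
    guard = LlamaGuard()  # the ChatGroq client is built here, at call time
    return await guard.ainvoke("User", [HumanMessage(content=text)])


if __name__ == "__main__":
    print(asyncio.run(check_message("Tell me a fun fact?")))
```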
10 changes: 7 additions & 3 deletions agent/research_assistant.py
@@ -1,4 +1,5 @@
 from datetime import datetime
+import os
 from langchain_openai import ChatOpenAI
 from langchain_groq import ChatGroq
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -10,7 +11,7 @@
 from langgraph.prebuilt import ToolNode
 
 from agent.tools import calculator, web_search
-from agent.llama_guard import llama_guard, LlamaGuardOutput
+from agent.llama_guard import LlamaGuard, LlamaGuardOutput
 
 
 class AgentState(MessagesState):
@@ -22,9 +23,11 @@ class AgentState(MessagesState):
 # if the /stream endpoint is called with stream_tokens=True (the default)
 models = {
     "gpt-4o-mini": ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True),
-    "llama-3.1-70b": ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5),
 }
 
+if os.getenv("GROQ_API_KEY") is not None:
+    models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
+
 tools = [web_search, calculator]
 current_date = datetime.now().strftime("%B %d, %Y")
 instructions = f"""
@@ -68,7 +71,8 @@ async def acall_model(state: AgentState, config: RunnableConfig):
 
 
 async def llama_guard_input(state: AgentState, config: RunnableConfig):
-    safety_output = await llama_guard("User", state["messages"])
+    llama_guard = LlamaGuard()
+    safety_output = await llama_guard.ainvoke("User", state["messages"])
     return {"safety": safety_output}
 
 
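Because the Groq model is now registered only when GROQ_API_KEY is present, callers cannot assume every model name exists in models. A hedged sketch of a lookup with a fallback; get_model is illustrative and not part of this PR:

```python
# Hypothetical helper, assuming the `models` dict defined above.
def get_model(name: str):
    # "gpt-4o-mini" is always registered, since OPENAI_API_KEY is the one
    # required variable, so it is a safe default when an optional model
    # (e.g. "llama-3.1-70b" without GROQ_API_KEY) is unavailable.
    return models.get(name, models["gpt-4o-mini"])
```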
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -44,9 +44,13 @@ dev = [
     "httpx~=0.26.0",
     "pre-commit",
     "pytest",
+    "pytest-env",
     "ruff",
 ]
 
 [tool.ruff]
 line-length = 100
 target-version = "py39"
+
+[tool.pytest_env]
+OPENAI_API_KEY = "sk-fake-openai-key"
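The [tool.pytest_env] table tells the pytest-env plugin to set the variable before the test session starts, which is what lets the fake keys come out of the CI workflow above. Roughly the same effect can be hand-rolled in a conftest.py, sketched here as a hypothetical alternative:

```python
# conftest.py -- hypothetical hand-rolled stand-in for [tool.pytest_env],
# for illustration only (the PR itself uses the plugin).
import os

# setdefault keeps a real key if one is already exported in the shell.
os.environ.setdefault("OPENAI_API_KEY", "sk-fake-openai-key")
```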