diff --git a/agent/llama_guard.py b/agent/llama_guard.py
index 40b44a3..3c1523a 100644
--- a/agent/llama_guard.py
+++ b/agent/llama_guard.py
@@ -1,3 +1,4 @@
+import os
 from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
 from langchain_core.prompts import PromptTemplate
 from langchain_groq import ChatGroq
@@ -75,10 +76,16 @@ def parse_llama_guard_output(output: str) -> LlamaGuardOutput:
 
 class LlamaGuard:
     def __init__(self):
+        if os.getenv("GROQ_API_KEY") is None:
+            print("GROQ_API_KEY not set, skipping LlamaGuard")
+            self.model = None
+            return
         self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0)
         self.prompt = PromptTemplate.from_template(llama_guard_instructions)
 
     async def ainvoke(self, role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
+        if self.model is None:
+            return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
         role_mapping = {"ai": "Agent", "human": "User"}
         messages_str = [
             f"{role_mapping[m.type]}: {m.content}" for m in messages if m.type in ["ai", "human"]
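
A minimal usage sketch of the new fallback path, assuming the module is importable as agent.llama_guard and that it exposes SafetyAssessment alongside LlamaGuard (both names appear in the patch); the role value "User" is an assumption, since the patch only shows the message-role mapping.

    import asyncio
    import os

    from langchain_core.messages import HumanMessage

    from agent.llama_guard import LlamaGuard, SafetyAssessment  # assumed import path

    async def main() -> None:
        # With GROQ_API_KEY unset, __init__ leaves self.model as None and
        # ainvoke short-circuits to a SAFE assessment without calling Groq.
        os.environ.pop("GROQ_API_KEY", None)
        guard = LlamaGuard()
        result = await guard.ainvoke("User", [HumanMessage(content="Hello!")])
        assert result.safety_assessment == SafetyAssessment.SAFE

    asyncio.run(main())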