Skip to content

Commit

Permalink
Teachability for any agent (#1091)
Browse files Browse the repository at this point in the history
* Partial implementation

* Partial implementation

* Fixes

* update tests

* cleanup

* update tests

* comments

* logging

* wording

* underscore

* Extend notebook for teachable GPTAssistantAgent

* Notebook for teachable GPTAssistantAgents

* Update notebook

* Update notebook

* Update notebook

* Update notebook

* revert file

* Update blog post and other documentation.

* pre-commit

* Address reviewer feedback.

* Add new nb link to examples page.

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
  • Loading branch information
rickyloynd-microsoft and sonichi authored Jan 7, 2024
1 parent 172df55 commit 3680197
Show file tree
Hide file tree
Showing 13 changed files with 1,719 additions and 472 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for TeachableAgent
- name: Install packages and dependencies for Teachability
run: |
pip install -e .[teachable]
pip uninstall -y openai
- name: Test TeachableAgent
- name: Test Teachability
if: matrix.python-version != '3.9' # diversify the python versions
run: |
pytest test/agentchat/contrib/test_teachable_agent.py
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ key_aoai.txt
base_aoai.txt
wolfram.txt

# DB on disk for TeachableAgent
# DB on disk for Teachability
tmp/
test/my_tmp/*

Expand Down
15 changes: 15 additions & 0 deletions autogen/agentchat/contrib/capabilities/agent_capability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from autogen.agentchat.assistant_agent import ConversableAgent


class AgentCapability:
"""Base class for composable capabilities that can be added to an agent."""

def __init__(self):
pass

def add_to_agent(self, agent: ConversableAgent):
"""
Adds a particular capability to the given agent. Must be implemented by the capability subclass.
An implementation will typically call agent.register_hook() one or more times. See teachability.py as an example.
"""
raise NotImplementedError

Large diffs are not rendered by default.

10 changes: 1 addition & 9 deletions autogen/agentchat/contrib/text_analyzer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ def __init__(
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
To disable llm-based auto reply, set to False.
teach_config (dict or None): Additional parameters used by TeachableAgent.
To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
- verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
- reset_db (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
- prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
- recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
- max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
**kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
"""
super().__init__(
Expand All @@ -56,7 +48,7 @@ def _analyze_in_reply(
) -> Tuple[bool, Union[str, Dict, None]]:
"""Analyzes the given text as instructed, and returns the analysis as a message.
Assumes exactly two messages containing the text to analyze and the analysis instructions.
See TeachableAgent.analyze for an example of how to use this method."""
See Teachability.analyze for an example of how to use this method."""
if self.llm_config is False:
raise ValueError("TextAnalyzerAgent requires self.llm_config to be set in its base class.")
if messages is None:
Expand Down
67 changes: 66 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.

def register_reply(
self,
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
Expand Down Expand Up @@ -757,7 +761,7 @@ def generate_code_execution_reply(
else:
messages_to_scan += 1

# iterate through the last n messages reversely
# iterate through the last n messages in reverse
# if code blocks are found, execute the code blocks and return the output
# if no code blocks are found, continue
for i in range(min(len(messages), messages_to_scan)):
Expand Down Expand Up @@ -1173,6 +1177,10 @@ def generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
Expand Down Expand Up @@ -1225,6 +1233,10 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
Expand Down Expand Up @@ -1757,3 +1769,56 @@ def _decorator(func: F) -> F:
return func

return _decorator

def register_hook(self, hookable_method: Callable, hook: Callable):
"""
Registers a hook to be called by a hookable method, in order to add a capability to the agent.
Registered hooks are kept in lists (one per hookable method), and are called in their order of registration.
Args:
hookable_method: A hookable method implemented by ConversableAgent.
hook: A method implemented by a subclass of AgentCapability.
"""
assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method."
hook_list = self.hook_lists[hookable_method]
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_last_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""

# If any required condition is not met, return the original message list.
hook_list = self.hook_lists[self.process_last_message]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
return None # No message to process.
if len(messages) == 0:
return messages # No message to process.
last_message = messages[-1]
if "function_call" in last_message:
return messages # Last message is a function call.
if "context" in last_message:
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
user_text = last_message["content"]
if not isinstance(user_text, str):
return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
if user_text == "exit":
return messages # Last message is an exit command.

# Call each hook (in order of registration) to process the user's message.
processed_user_text = user_text
for hook in hook_list:
processed_user_text = hook(processed_user_text)
if processed_user_text == user_text:
return messages # No hooks actually modified the user's message.

# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_text
return messages
Loading

0 comments on commit 3680197

Please sign in to comment.