Skip to content

Commit

Permalink
Refactor logging out of components as an intermediate step to cleanin…
Browse files Browse the repository at this point in the history
…g up `get_last_log` uses.

This change does a few things:
* Add a new entity agent that implements `get_last_log` and remove `get_last_log` from EntityAgent.
* Remove all `verbose` and `color` from new components, as well as their `get_last_log`. Instead, pass a channel for outputting debug data to log.
* Modify the agent factory to use the new system.

PiperOrigin-RevId: 653792934
Change-Id: I76e75ab2ecf0fc926b2dc773fbb23c2459616aab
  • Loading branch information
duenez authored and copybara-github committed Jul 18, 2024
1 parent 2bdbec9 commit 8fe5710
Show file tree
Hide file tree
Showing 15 changed files with 237 additions and 189 deletions.
7 changes: 0 additions & 7 deletions concordia/agents/entity_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,3 @@ def observe(self, observation: str) -> None:

self._phase = component_v2.Phase.UPDATE
self._parallel_call_('update')

def get_last_log(self):
logs = self._parallel_call_('get_last_log')
return {
'__act__': self._act_component.get_last_log(),
**logs,
}
88 changes: 88 additions & 0 deletions concordia/agents/entity_agent_with_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A modular entity agent using the new component system with side logging."""

from collections.abc import Mapping
import types
from typing import Any

from concordia.agents import entity_agent
from concordia.typing import agent
from concordia.typing import component_v2
from concordia.utils import measurements as measurements_lib

import reactivex as rx


class EntityAgentWithLogging(entity_agent.EntityAgent, agent.GenerativeAgent):
"""An agent that exposes the latest information of each component."""

def __init__(
self,
agent_name: str,
act_component: component_v2.ActingComponent,
context_processor: component_v2.ContextProcessorComponent | None = None,
context_components: Mapping[str, component_v2.ContextComponent] = (
types.MappingProxyType({})
),
component_logging: measurements_lib.Measurements | None = None,
):
"""Initializes the agent.
The passed components will be owned by this entity agent (i.e. their
`set_entity` method will be called with this entity as the argument).
Whenever `get_last_log` is called, the latest values published in all the
channels in the given measurements object will be returned as a mapping of
channel name to value.
Args:
agent_name: The name of the agent.
act_component: The component that will be used to act.
context_processor: The component that will be used to process contexts. If
None, a NoOpContextProcessor will be used.
context_components: The ContextComponents that will be used by the agent.
component_logging: The channels where components publish events.
"""
super().__init__(agent_name=agent_name,
act_component=act_component,
context_processor=context_processor,
context_components=context_components)
self._log: Mapping[str, Any] = {}
self._tick = rx.subject.Subject()
self._component_logging = component_logging
if self._component_logging is not None:
self._channel_names = list(self._component_logging.available_channels())
channels = [
self._component_logging.get_channel(channel) # pylint: disable=attribute-error pytype mistakenly forgets that `_component_logging` is not None.
for channel in self._channel_names
]
rx.with_latest_from(self._tick, *channels).subscribe(self._set_log)
else:
self._channel_names = []

def _set_log(self, log: tuple[Any, ...]) -> None:
"""Set the logging object to return from get_last_log.
Args:
log: A tuple with the tick first, and the latest log from each component.
"""
tick_value, *channel_values = log
assert tick_value is None
self._log = dict(zip(self._channel_names, channel_values, strict=True))

def get_last_log(self):
self._tick.on_next(None) # Trigger the logging.
return self._log
26 changes: 8 additions & 18 deletions concordia/components/agent/v2/all_similar_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
import termcolor
from concordia.typing import logging


_ASSOCIATIVE_RETRIEVAL = legacy_associative_memory.RetrieveAssociative()
Expand All @@ -42,7 +42,7 @@ def __init__(
),
num_memories_to_retrieve: int = 25,
pre_act_key: str = 'Relevant memories',
verbose: bool = False,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initialize a component to report relevant memories (similar to a prompt).
Expand All @@ -54,16 +54,15 @@ def __init__(
num_memories_to_retrieve: The number of memories to retrieve.
pre_act_key: Prefix to add to the output of the component when called
in `pre_act`.
verbose: Whether to print the state of the component.
logging_channel: The channel to log debug information to.
"""
super().__init__(pre_act_key)
self._verbose = verbose
self._model = model
self._memory_component_name = memory_component_name
self._state = ''
self._components = dict(components)
self._num_memories_to_retrieve = num_memories_to_retrieve
self._last_log = None
self._logging_channel = logging_channel

def _make_pre_act_value(self) -> str:
agent_name = self.get_entity().name
Expand Down Expand Up @@ -108,21 +107,12 @@ def _make_pre_act_value(self) -> str:
terminators=(),
)

if self._verbose:
print(termcolor.colored(prompt.view().text(), 'green'), end='')
print(termcolor.colored(f'Query: {query}\n', 'green'), end='')
print(termcolor.colored(new_prompt.view().text(), 'green'), end='')
print(termcolor.colored(result, 'green'), end='')

self._last_log = {
'State': result,
self._logging_channel({
'Key': self.get_pre_act_key(),
'Value': result,
'Initial chain of thought': prompt.view().text().splitlines(),
'Query': f'{query}',
'Final chain of thought': new_prompt.view().text().splitlines(),
}
})

return result

def get_last_log(self):
if self._last_log:
return self._last_log.copy()
33 changes: 20 additions & 13 deletions concordia/components/agent/v2/concat_act_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from concordia.typing import clock as game_clock
from concordia.typing import component_v2
from concordia.typing import entity as entity_lib
from concordia.typing import logging
from concordia.utils import helper_functions
from typing_extensions import override

DEFAULT_PRE_ACT_KEY = 'Act'


class ConcatActComponent(component_v2.ActingComponent):
"""A component which concatenates contexts from context components.
Expand All @@ -42,6 +45,8 @@ def __init__(
model: language_model.LanguageModel,
clock: game_clock.GameClock,
component_order: Sequence[str] | None = None,
pre_act_key: str = DEFAULT_PRE_ACT_KEY,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initializes the agent.
Expand All @@ -58,6 +63,8 @@ def __init__(
component cannot appear twice in the component order. All components in
the component order must be in the `ComponentContextMapping` passed to
`get_action_attempt`.
pre_act_key: Prefix to add to the context of the component.
logging_channel: The channel to use for debug logging.
Raises:
ValueError: If the component order is not None and contains duplicate
Expand All @@ -76,7 +83,8 @@ def __init__(
+ ', '.join(self._component_order)
)

self._last_log = None
self._pre_act_key = pre_act_key
self._logging_channel = logging_channel

def _context_for_action(
self,
Expand Down Expand Up @@ -116,14 +124,14 @@ def get_action_attempt(
max_tokens=2200,
answer_prefix=output,
)
self._make_update_log(output, prompt)
self._log(output, prompt)
return output
elif action_spec.output_type == entity_lib.OutputType.CHOICE:
idx = prompt.multiple_choice_question(
question=call_to_action, answers=action_spec.options
)
output = action_spec.options[idx]
self._make_update_log(output, prompt)
self._log(output, prompt)
return output
elif action_spec.output_type == entity_lib.OutputType.FLOAT:
prefix = self.get_entity().name + ' '
Expand All @@ -132,7 +140,7 @@ def get_action_attempt(
max_tokens=2200,
answer_prefix=prefix,
)
self._make_update_log(sampled_text, prompt)
self._log(sampled_text, prompt)
try:
return str(float(sampled_text))
except ValueError:
Expand All @@ -143,12 +151,11 @@ def get_action_attempt(
'Supported output types are: FREE, CHOICE, and FLOAT.'
)

def _make_update_log(self,
result: str,
prompt: interactive_document.InteractiveDocument):
self._last_log = {'Output': result,
'Prompt': prompt.view().text().splitlines()}

def get_last_log(self):
if self._last_log:
return self._last_log.copy()
def _log(self,
result: str,
prompt: interactive_document.InteractiveDocument):
self._logging_channel({
'Key': self._pre_act_key,
'Value': result,
'Prompt': prompt.view().text().splitlines(),
})
6 changes: 3 additions & 3 deletions concordia/components/agent/v2/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""A simple acting component that aggregates contexts from components."""

from concordia.components.agent.v2 import action_spec_ignored
from concordia.typing import logging

DEFAULT_PRE_ACT_KEY = 'Constant'

Expand All @@ -27,13 +28,15 @@ def __init__(
self,
state: str,
pre_act_key: str = DEFAULT_PRE_ACT_KEY,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initializes the agent.
Args:
state: the state of the component.
pre_act_key: Prefix to add to the output of the component when called
in `pre_act`.
logging_channel: The channel to use for debug logging.
Raises:
ValueError: If the component order is not None and contains duplicate
Expand All @@ -44,6 +47,3 @@ def __init__(

def _make_pre_act_value(self) -> str:
return self._state

def get_last_log(self):
return {'State': self._state}
28 changes: 5 additions & 23 deletions concordia/components/agent/v2/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import component_v2
from concordia.typing import entity as entity_lib
from concordia.typing import logging
from concordia.utils import concurrency
import termcolor


DEFAULT_PRE_ACT_KEY = 'Identity characteristics'
Expand Down Expand Up @@ -54,8 +54,7 @@ def __init__(
),
num_memories_to_retrieve: int = 25,
pre_act_key: str = DEFAULT_PRE_ACT_KEY,
verbose: bool = False,
log_color: str = 'green',
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
"""Initialize an identity component.
Expand All @@ -67,19 +66,16 @@ def __init__(
num_memories_to_retrieve: how many related memories to retrieve per query
pre_act_key: Prefix to add to the output of the component when called
in `pre_act`.
verbose: whether or not to print the result for debugging
log_color: color to print the debug log
logging_channel: The channel to use for debug logging.
"""
super().__init__(pre_act_key)
self._model = model
self._memory_component_name = memory_component_name
self._last_log = None

self._queries = queries
self._num_memories_to_retrieve = num_memories_to_retrieve

self._verbose = verbose
self._log_color = log_color
self._logging_channel = logging_channel

def _query_memory(self, query: str) -> str:
agent_name = self.get_entity().name
Expand Down Expand Up @@ -111,21 +107,10 @@ def _make_pre_act_value(self) -> str:
[f'{query}: {result}' for query, result in zip(self._queries, results)]
)

self._last_log = {
'State': output,
}
if self._verbose:
self._log(output)
self._logging_channel({'Key': self.get_pre_act_key(), 'Value': output})

return output

def _log(self, entry: str):
print(termcolor.colored(entry, self._log_color), end='')

def get_last_log(self):
if self._last_log:
return self._last_log.copy()


class IdentityWithoutPreAct(action_spec_ignored.ActionSpecIgnored):
"""An identity component that does not output its state to pre_act."""
Expand All @@ -151,6 +136,3 @@ def pre_act(

def update(self) -> None:
self._component.update()

def get_last_log(self):
return self._component.get_last_log()
13 changes: 9 additions & 4 deletions concordia/components/agent/v2/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
"""Component that provides the default role playing instructions to an agent."""

from concordia.components.agent.v2 import constant
from concordia.typing import logging

DEFAULT_INSTRUCTIONS_PRE_ACT_KEY = 'Role playing instructions'


class Instructions(constant.Constant):
"""A component that provides the role playing instructions for the agent."""

def __init__(self,
agent_name: str,
pre_act_key: str = DEFAULT_INSTRUCTIONS_PRE_ACT_KEY):
def __init__(
self,
agent_name: str,
pre_act_key: str = DEFAULT_INSTRUCTIONS_PRE_ACT_KEY,
logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
):
state = (
f'The instructions for how to play the role of {agent_name} are as '
'follows. This is a social science experiment studying how well you '
Expand All @@ -38,4 +42,5 @@ def __init__(self,
f'into account all information about {agent_name} that you have. '
'Always use third-person limited perspective.'
)
super().__init__(state=state, pre_act_key=pre_act_key)
super().__init__(
state=state, pre_act_key=pre_act_key, logging_channel=logging_channel)
Loading

0 comments on commit 8fe5710

Please sign in to comment.