Skip to content

Commit

Permalink
switching all the examples to new measurments system
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588364902
Change-Id: Ia79696dc079e94960cb463c6daa01868d2569260
  • Loading branch information
vezhnick authored and copybara-github committed Dec 6, 2023
1 parent 6c6b49a commit 3db52d7
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 213 deletions.
10 changes: 0 additions & 10 deletions examples/phone/calendar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -502,20 +502,10 @@
"## Summary and analysis of the episode"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j71OiuPot5UV"
},
"source": [
"## Save results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "O4jp0xGXvOAJ"
},
"outputs": [],
Expand Down
168 changes: 102 additions & 66 deletions examples/three_key_questions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"source": [
"# @title Imports\n",
"\n",
"import collections\n",
"import concurrent.futures\n",
"import datetime\n",
"\n",
Expand All @@ -84,9 +85,10 @@
"from concordia.clocks import game_clock\n",
"from concordia.components import game_master as gm_components\n",
"from concordia.environment import game_master\n",
"from concordia.environment.metrics import common_sense_morality\n",
"from concordia.environment.metrics import goal_achievement\n",
"from concordia.environment.metrics import reputation\n",
"from concordia.metrics import goal_achievement\n",
"from concordia.metrics import common_sense_morality\n",
"from concordia.metrics import opinion_of_others\n",
"from concordia.utils import measurements as measurements_lib\n",
"from concordia.google.language_model import sax_model\n",
"from concordia.utils import html as html_lib\n",
"from concordia.utils import plotting\n"
Expand Down Expand Up @@ -248,7 +250,7 @@
" gender='female',\n",
" goal='Alice wants Bob to accept his car is trashed and back off.',\n",
" context=shared_context,\n",
" traits='responsibility: high; aggression: low',\n",
" traits='responsibility: low; aggression: high',\n",
" ),\n",
" formative_memories.AgentConfig(\n",
" name='Bob',\n",
Expand Down Expand Up @@ -285,8 +287,11 @@
},
"outputs": [],
"source": [
"def build_agent(agent_config):\n",
"\n",
"def build_agent(\n",
" agent_config,\n",
" player_names: list[str],\n",
" measurements: measurements_lib.Measurements | None = None,\n",
"):\n",
" mem = formative_memory_factory.make_memories(agent_config)\n",
"\n",
" self_perception = components.self_perception.SelfPerception(\n",
Expand All @@ -298,50 +303,92 @@
" verbose=True,\n",
" )\n",
" situation_perception = components.situation_perception.SituationPerception(\n",
" name='situation perception',\n",
" model=model,\n",
" memory=mem,\n",
" agent_name=agent_config.name,\n",
" clock_now=clock.now,\n",
" verbose=True,\n",
" )\n",
" name='situation perception',\n",
" model=model,\n",
" memory=mem,\n",
" agent_name=agent_config.name,\n",
" clock_now=clock.now,\n",
" verbose=True,\n",
" )\n",
" person_by_situation = components.person_by_situation.PersonBySituation(\n",
" name='person by situation',\n",
" model=model,\n",
" memory=mem,\n",
" agent_name=agent_config.name,\n",
" clock_now=clock.now,\n",
" components=[self_perception, situation_perception],\n",
" verbose=True,\n",
" )\n",
" name='person by situation',\n",
" model=model,\n",
" memory=mem,\n",
" agent_name=agent_config.name,\n",
" clock_now=clock.now,\n",
" components=[self_perception, situation_perception],\n",
" verbose=True,\n",
" )\n",
" persona = components.sequential.Sequential(\n",
" name='persona',\n",
" components=[\n",
" self_perception,\n",
" situation_perception,\n",
" person_by_situation,\n",
" ],\n",
" )\n",
" name='persona',\n",
" components=[\n",
" self_perception,\n",
" situation_perception,\n",
" person_by_situation,\n",
" ],\n",
" )\n",
" current_time_component = components.report_function.ReportFunction(\n",
" name='current_time',\n",
" function=clock.current_time_interval_str)\n",
" name='current_time', function=clock.current_time_interval_str\n",
" )\n",
"\n",
" current_obs = components.observation.Observation(agent_config.name, mem)\n",
" summary_obs = components.observation.ObservationSummary(\n",
" model=model,\n",
" agent_name=agent_config.name,\n",
" components=[persona],\n",
" )\n",
" model=model,\n",
" agent_name=agent_config.name,\n",
" components=[persona],\n",
" )\n",
"\n",
" goal_metric = goal_achievement.GoalAchievementMetric(\n",
" model=model,\n",
" player_name=agent_config.name,\n",
" player_goal=agent_config.goal,\n",
" clock=clock,\n",
" name='Goal Achievement',\n",
" measurements=measurements,\n",
" channel='goal_achievement',\n",
" verbose=False,\n",
" )\n",
" morality_metric = common_sense_morality.CommonSenseMoralityMetric(\n",
" model=model,\n",
" player_name=agent_config.name,\n",
" clock=clock,\n",
" name='Morality',\n",
" verbose=False,\n",
" measurements=measurements,\n",
" channel='common_sense_morality',\n",
" )\n",
"\n",
" agent = basic_agent.BasicAgent(\n",
" model,\n",
" mem,\n",
" agent_name=agent_config.name,\n",
" clock=clock,\n",
" verbose=False,\n",
" components=[persona, current_time_component,summary_obs, current_obs],\n",
" components=[\n",
" persona,\n",
" current_time_component,\n",
" summary_obs,\n",
" current_obs,\n",
" goal_metric,\n",
" morality_metric,\n",
" ],\n",
" update_interval=time_step,\n",
" )\n",
" return agent\n"
" reputation_metric = opinion_of_others.OpinionOfOthersMetric(\n",
" model=model,\n",
" player_name=agent_config.name,\n",
" player_names=player_names,\n",
" context_fn=agent.state,\n",
" clock=clock,\n",
" name='Opinion',\n",
" verbose=False,\n",
" measurements=measurements,\n",
" channel='opinion_of_others',\n",
" question=\"What is {opining_player}'s opinion of {of_player}?\",\n",
" )\n",
" agent.add_component(reputation_metric)\n",
"\n",
" return agent"
]
},
{
Expand All @@ -353,11 +400,18 @@
"outputs": [],
"source": [
"player_configs = player_configs[:NUM_PLAYERS]\n",
"player_names = [player.name for player in player_configs][:NUM_PLAYERS]\n",
"measurements = measurements_lib.Measurements()\n",
"\n",
"players = []\n",
"\n",
"with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as pool:\n",
" for agent in pool.map(build_agent, player_configs[:NUM_PLAYERS]):\n",
" for agent in pool.map(build_agent,\n",
" player_configs[:NUM_PLAYERS],\n",
" # All players get the same `player_names`.\n",
" [player_names] * NUM_PLAYERS,\n",
" # All players get the same `measurements` object.\n",
" [measurements] * NUM_PLAYERS):\n",
" players.append(agent)\n"
]
},
Expand Down Expand Up @@ -458,29 +512,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5SpNVmlh6_hp"
},
"outputs": [],
"source": [
"# @title Metrics\n",
"player_goals = {\n",
" player_config.name: player_config.goal for player_config in player_configs\n",
"}\n",
"\n",
"goal_metric = goal_achievement.GoalAchievementMetric(\n",
" model, player_goals, clock, 'Goal achievement', verbose=False)\n",
"morality_metric = common_sense_morality.CommonSenseMoralityMetric(\n",
" model, players, clock, 'Morality', verbose=False)\n",
"reputation_metric = reputation.ReputationMetric(\n",
" model, players, clock, 'Reputation', verbose=False)\n",
"\n",
"metrics = [goal_metric, morality_metric, reputation_metric]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -502,7 +533,6 @@
" convo_externality,\n",
" direct_effect_externality,\n",
" ],\n",
" measurements=metrics,\n",
" randomise_initiative=True,\n",
" player_observes_event=False,\n",
" verbose=True,\n",
Expand Down Expand Up @@ -584,12 +614,18 @@
"outputs": [],
"source": [
"# @title Metrics plotting\n",
"tb = widgets.TabBar([metric.name() for metric in metrics])\n",
"\n",
"for metric in metrics:\n",
" with tb.output_to(metric.name()):\n",
" plotting.plot_metric_line(metric)\n",
" plotting.plot_metric_pie(metric)\n"
"colab_import.reload_module(plotting)\n",
"\n",
"group_by = collections.defaultdict(lambda: 'player')\n",
"group_by['opinion_of_others'] = 'of_player'\n",
"\n",
"tb = widgets.TabBar([channel for channel in measurements.available_channels()])\n",
"for channel in measurements.available_channels():\n",
" with tb.output_to(channel):\n",
" plotting.plot_line_measurement_channel(measurements, channel,\n",
" group_by=group_by[channel],\n",
" xaxis='time_str')"
]
},
{
Expand Down
Loading

0 comments on commit 3db52d7

Please sign in to comment.