Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to store all world states in information state tree #1074

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 142 additions & 76 deletions open_spiel/algorithms/infostate_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ InfostateNode* InfostateNode::AddChild(std::unique_ptr<InfostateNode> child) {
return children_.back().get();
}

InfostateNode* InfostateNode::GetChild(
const std::string& infostate_string) const {
InfostateNode* InfostateNode::GetChild(std::string_view infostate_string) const {
for (const std::unique_ptr<InfostateNode>& child : children_) {
if (child->infostate_string() == infostate_string) return child.get();
}
Expand Down Expand Up @@ -161,10 +160,13 @@ void InfostateNode::SwapParent(std::unique_ptr<InfostateNode> self,
InfostateTree::InfostateTree(const std::vector<const State*>& start_states,
const std::vector<double>& chance_reach_probs,
std::shared_ptr<Observer> infostate_observer,
Player acting_player, int max_move_ahead_limit)
Player acting_player,
bool store_world_states,
int max_move_ahead_limit)
: acting_player_(acting_player),
infostate_observer_(std::move(infostate_observer)),
root_(MakeRootNode()) {
root_(MakeRootNode()),
store_all_world_states_(store_world_states) {
SPIEL_CHECK_FALSE(start_states.empty());
SPIEL_CHECK_EQ(start_states.size(), chance_reach_probs.size());
SPIEL_CHECK_GE(acting_player_, 0);
Expand All @@ -178,7 +180,8 @@ InfostateTree::InfostateTree(const std::vector<const State*>& start_states,
}

for (int i = 0; i < start_states.size(); ++i) {
RecursivelyBuildTree(root_.get(), /*depth=*/1, *start_states[i],
RecursivelyBuildTree(root_.get(), /*depth=*/1,
std::shared_ptr<const State>(start_states[i]->Clone()),
start_max_move_number + max_move_ahead_limit,
chance_reach_probs[i]);
}
Expand Down Expand Up @@ -225,11 +228,11 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeNode(
: std::vector<Action>();
// Instantiate node using new to make sure that we can call
// the private constructor.
auto node = std::unique_ptr<InfostateNode>(new InfostateNode(
auto node = new InfostateNode(
*this, parent, parent->num_children(), type, infostate_string,
terminal_utility, terminal_ch_reach_prob, depth, std::move(legal_actions),
std::move(terminal_history)));
return node;
std::move(terminal_history));
return std::unique_ptr<InfostateNode>{node};
}

std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode() const {
Expand All @@ -241,46 +244,56 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode() const {
/*depth=*/0, /*legal_actions=*/{}, /*terminal_history=*/{}));
}

void InfostateTree::UpdateLeafNode(InfostateNode* node, const State& state,
size_t leaf_depth,
double chance_reach_probs) {
tree_height_ = std::max(tree_height_, leaf_depth);
node->corresponding_states_.push_back(state.Clone());
void InfostateTree::UpdateNode(InfostateNode* node,
std::shared_ptr<const State> state,
size_t node_depth,
double chance_reach_probs) {
tree_height_ = std::max(tree_height_, node_depth);
node->corresponding_states_.push_back(std::move(state));
node->corresponding_ch_reaches_.push_back(chance_reach_probs);
}

void InfostateTree::RecursivelyBuildTree(InfostateNode* parent, size_t depth,
const State& state, int move_limit,
std::shared_ptr<const State> state,
int move_limit,
double chance_reach_prob) {
if (state.IsTerminal())
return BuildTerminalNode(parent, depth, state, chance_reach_prob);
else if (state.IsPlayerActing(acting_player_))
return BuildDecisionNode(parent, depth, state, move_limit,
chance_reach_prob);
else
return BuildObservationNode(parent, depth, state, move_limit,
chance_reach_prob);
auto [child_node, leaf_update] = std::invoke([&] {
if (state->IsTerminal())
return BuildTerminalNode(parent, depth, state, chance_reach_prob);
else if (state->IsPlayerActing(acting_player_))
return BuildDecisionNode(parent, depth, state, move_limit,
chance_reach_prob);
else
return BuildObservationNode(parent, depth, state, move_limit,
chance_reach_prob);
});
if(store_all_world_states_ or leaf_update) {
UpdateNode(child_node, std::move(state), depth, chance_reach_prob);
}
}

void InfostateTree::BuildTerminalNode(InfostateNode* parent, size_t depth,
const State& state,
std::pair<InfostateNode*, bool>
InfostateTree::BuildTerminalNode(InfostateNode* parent, size_t depth,
const std::shared_ptr<const State>& state,
double chance_reach_prob) {
const double terminal_utility = state.Returns()[acting_player_];
const double terminal_utility = state->Returns()[acting_player_];
InfostateNode* terminal_node = parent->AddChild(
MakeNode(parent, kTerminalInfostateNode,
infostate_observer_->StringFrom(state, acting_player_),
terminal_utility, chance_reach_prob, depth, &state));
UpdateLeafNode(terminal_node, state, depth, chance_reach_prob);
infostate_observer_->StringFrom(*state, acting_player_),
terminal_utility, chance_reach_prob, depth, state.get()));
return {terminal_node, true};
}

void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
const State& state, int move_limit,
double chance_reach_prob) {
std::pair<InfostateNode*, bool>
InfostateTree::BuildDecisionNode(InfostateNode* parent,
size_t depth,
const std::shared_ptr<const State>& state,
int move_limit, double chance_reach_prob) {
SPIEL_DCHECK_EQ(parent->type(), kObservationInfostateNode);
std::string info_state =
infostate_observer_->StringFrom(state, acting_player_);
infostate_observer_->StringFrom(*state, acting_player_);
InfostateNode* decision_node = parent->GetChild(info_state);
const bool is_leaf_node = state.MoveNumber() >= move_limit;
const bool is_leaf_node = state->MoveNumber() >= move_limit;

if (decision_node) {
// The decision node has been already constructed along with children
Expand All @@ -289,50 +302,75 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
SPIEL_DCHECK_EQ(decision_node->type(), kDecisionInfostateNode);

if (is_leaf_node) { // Do not build deeper.
return UpdateLeafNode(decision_node, state, depth, chance_reach_prob);
return {decision_node, true};
}

if (state.IsSimultaneousNode()) {
const ActionView action_view(state);
for (int i = 0; i < action_view.legal_actions[acting_player_].size();
if (state->IsSimultaneousNode()) {
const ActionView action_view(*state);
for (int i = 0;
i < action_view.legal_actions[acting_player_].size();
++i) {
InfostateNode* observation_node = decision_node->child_at(i);
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);

for (Action flat_actions :
action_view.fixed_action(acting_player_, i)) {
std::unique_ptr<State> child = state.Child(flat_actions);
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
auto child_state = std::shared_ptr{state->Child(flat_actions)};
// Only now we can advance the state, when we have all actions.
RecursivelyBuildTree(observation_node,
depth + 2,
child_state,
move_limit,
chance_reach_prob);
if(store_all_world_states_ and not observation_node->is_filler_node()) {
UpdateNode(observation_node,
std::move(child_state),
depth + 2,
chance_reach_prob);
}
}
}
} else {
std::vector<Action> legal_actions = state.LegalActions(acting_player_);
std::vector<Action> legal_actions = state->LegalActions(acting_player_);
for (int i = 0; i < legal_actions.size(); ++i) {
InfostateNode* observation_node = decision_node->child_at(i);
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);
std::unique_ptr<State> child = state.Child(legal_actions.at(i));
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
auto child_state = std::shared_ptr{state->Child(legal_actions.at(i))};
// Only now we can advance the state, when we have all actions.
RecursivelyBuildTree(observation_node,
depth + 2,
child_state,
move_limit,
chance_reach_prob);
if(store_all_world_states_ and not observation_node->is_filler_node()) {
UpdateNode(observation_node,
std::move(child_state),
depth,
chance_reach_prob);
}
}
}
} else { // The decision node was not found yet.
decision_node = parent->AddChild(MakeNode(
parent, kDecisionInfostateNode, info_state,
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth, &state));
/*terminal_utility=*/NAN,
/*chance_reach_prob=*/NAN,
depth,
state.get()));

if (is_leaf_node) { // Do not build deeper.
return UpdateLeafNode(decision_node, state, depth, chance_reach_prob);
return {decision_node, true};
}

// Build observation nodes right away after the decision node.
// This is because the player might be acting multiple times in a row:
// each time it might get some observations that branch the infostate
// tree.

if (state.IsSimultaneousNode()) {
ActionView action_view(state);
for (int i = 0; i < action_view.legal_actions[acting_player_].size();
if (state->IsSimultaneousNode()) {
ActionView action_view(*state);
for (int i = 0;
i < action_view.legal_actions[acting_player_].size();
++i) {
// We build a dummy observation node.
// We can't ask for a proper infostate string or an originating state,
Expand All @@ -344,88 +382,115 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
/*infostate_string=*/kFillerInfostate,
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth,
/*originating_state=*/nullptr));

for (Action flat_actions :
action_view.fixed_action(acting_player_, i)) {
auto child_state = std::shared_ptr{state->Child(flat_actions)};
// Only now we can advance the state, when we have all actions.
std::unique_ptr<State> child = state.Child(flat_actions);
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
RecursivelyBuildTree(observation_node,
depth + 2,
child_state,
move_limit,
chance_reach_prob);
if(store_all_world_states_ and not observation_node->is_filler_node()) {
UpdateNode(observation_node,
std::move(child_state),
depth,
chance_reach_prob);
}
}
}
} else { // Not a sim move node.
for (Action a : state.LegalActions()) {
std::unique_ptr<State> child = state.Child(a);
for (Action a : state->LegalActions()) {
std::shared_ptr child = state->Child(a);
InfostateNode* observation_node = decision_node->AddChild(
MakeNode(decision_node, kObservationInfostateNode,
infostate_observer_->StringFrom(*child, acting_player_),
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth,
child.get()));
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
RecursivelyBuildTree(observation_node,
depth + 2,
child,
move_limit,
chance_reach_prob);
if(store_all_world_states_ and not observation_node->is_filler_node()) {
UpdateNode(observation_node,
std::move(child),
depth,
chance_reach_prob);
}
}
}
}
return {decision_node, false};
}

void InfostateTree::BuildObservationNode(InfostateNode* parent, size_t depth,
const State& state, int move_limit,
double chance_reach_prob) {
SPIEL_DCHECK_TRUE(state.IsChanceNode() ||
!state.IsPlayerActing(acting_player_));
const bool is_leaf_node = state.MoveNumber() >= move_limit;
std::pair<InfostateNode*, bool>
InfostateTree::BuildObservationNode(InfostateNode* parent, size_t depth,
const std::shared_ptr<const State>& state,
int move_limit,
double chance_reach_prob) {
SPIEL_DCHECK_TRUE(state->IsChanceNode() ||
!state->IsPlayerActing(acting_player_));
const bool is_leaf_node = state->MoveNumber() >= move_limit;
const std::string info_state =
infostate_observer_->StringFrom(state, acting_player_);
infostate_observer_->StringFrom(*state, acting_player_);

InfostateNode* observation_node = parent->GetChild(info_state);
if (!observation_node) {
observation_node = parent->AddChild(MakeNode(
parent, kObservationInfostateNode, info_state,
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth, &state));
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth, state.get()));
}
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);

if (is_leaf_node) { // Do not build deeper.
return UpdateLeafNode(observation_node, state, depth, chance_reach_prob);
return {observation_node, true};
}

if (state.IsChanceNode()) {
for (std::pair<Action, double> action_prob : state.ChanceOutcomes()) {
std::unique_ptr<State> child = state.Child(action_prob.first);
RecursivelyBuildTree(observation_node, depth + 1, *child, move_limit,
if (state->IsChanceNode()) {
for (std::pair<Action, double> action_prob : state->ChanceOutcomes()) {
RecursivelyBuildTree(observation_node, depth + 1,
state->Child(action_prob.first), move_limit,
chance_reach_prob * action_prob.second);
}
} else {
for (Action a : state.LegalActions()) {
std::unique_ptr<State> child = state.Child(a);
RecursivelyBuildTree(observation_node, depth + 1, *child, move_limit,
for (Action a : state->LegalActions()) {
RecursivelyBuildTree(observation_node,
depth + 1,
state->Child(a),
move_limit,
chance_reach_prob);
}
}
return {observation_node, false};
}
int InfostateTree::root_branching_factor() const {
return root_->num_children();
}

std::shared_ptr<InfostateTree> MakeInfostateTree(const Game& game,
Player acting_player,
bool store_world_states,
int max_move_limit) {
// Uses new instead of make_shared, because shared_ptr is not a friend and
// can't call private constructors.
return std::shared_ptr<InfostateTree>(new InfostateTree(
{game.NewInitialState().get()}, /*chance_reach_probs=*/{1.},
game.MakeObserver(kInfoStateObsType, {}), acting_player, max_move_limit));
game.MakeObserver(kInfoStateObsType, {}), acting_player,
store_world_states, max_move_limit));
}

std::shared_ptr<InfostateTree> MakeInfostateTree(
const std::vector<InfostateNode*>& start_nodes, int max_move_ahead_limit) {
const std::vector<InfostateNode*>& start_nodes,
bool store_world_states, int max_move_ahead_limit) {
std::vector<const InfostateNode*> const_nodes(start_nodes.begin(),
start_nodes.end());
return MakeInfostateTree(const_nodes, max_move_ahead_limit);
return MakeInfostateTree(const_nodes, store_world_states, max_move_ahead_limit);
}

std::shared_ptr<InfostateTree> MakeInfostateTree(
const std::vector<const InfostateNode*>& start_nodes,
bool store_world_states,
int max_move_ahead_limit) {
SPIEL_CHECK_FALSE(start_nodes.empty());
const InfostateNode* some_node = start_nodes[0];
Expand Down Expand Up @@ -458,17 +523,18 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
// can't call private constructors.
return std::shared_ptr<InfostateTree>(new InfostateTree(
start_states, chance_reach_probs, originating_tree.infostate_observer_,
originating_tree.acting_player_, max_move_ahead_limit));
originating_tree.acting_player_, store_world_states, max_move_ahead_limit));
}

std::shared_ptr<InfostateTree> MakeInfostateTree(
const std::vector<const State*>& start_states,
const std::vector<double>& chance_reach_probs,
std::shared_ptr<Observer> infostate_observer, Player acting_player,
int max_move_ahead_limit) {
bool store_world_states, int max_move_ahead_limit) {
return std::shared_ptr<InfostateTree>(
new InfostateTree(start_states, chance_reach_probs, infostate_observer,
acting_player, max_move_ahead_limit));
new InfostateTree(start_states, chance_reach_probs,
std::move(infostate_observer), acting_player,
store_world_states, max_move_ahead_limit));
}
SequenceId InfostateTree::empty_sequence() const {
return root().sequence_id();
Expand Down
Loading