Skip to content

Commit

Permalink
Change parameter passing style for Game objects in order to support Python games.
Browse files Browse the repository at this point in the history

The pybind11 smart_holder logic creates a shared_ptr for a Python-created object only when it is required to do so. Consequently, if a Python-implemented game is passed from Python to C++ as Game& and a C++ function then calls shared_from_this() on it, that call fails unless a C++ shared_ptr to the game already exists for some other reason.

The fix is either:
a - Amend the C++ interface to take shared_ptr instead of refs
b - Introduce a lambda function in the pybind interface, taking a shared_ptr and dereferencing it to call the ref-based C++ implementation

Either option will result in pybind creating a shared_ptr for us before calling our C++ code.

To minimize disruption to existing code, and forestall future failures, I've applied change (b) everywhere I could see, even though not every case was failing (because not every case called shared_from_this in the C++ implementation).

For further details of the relevant pybind internals, see pybind/pybind11#3023

fixes: #905
PiperOrigin-RevId: 469016236
Change-Id: I9467eeb992f3463a432cc7060c46404d2bbd4638
  • Loading branch information
elkhrt authored and lanctot committed Aug 21, 2022
1 parent f4121ea commit c9f2e37
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 97 deletions.
5 changes: 5 additions & 0 deletions open_spiel/python/games/kuhn_poker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def test_exploitability_uniform_random_cc(self):
self.assertAlmostEqual(
pyspiel.exploitability(game, test_policy), expected_nash_conv / 2)

def test_cfr_cc(self):
"""Runs a C++ CFR algorithm on the game."""
game = pyspiel.load_game("python_kuhn_poker")
unused_results = pyspiel.CFRSolver(game)


if __name__ == "__main__":
absltest.main()
48 changes: 29 additions & 19 deletions open_spiel/python/pybind11/algorithms_corr_dist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,37 @@ void init_pyspiel_algorithms_corr_dist(py::module& m) {
.def_readonly("conditional_best_response_policies",
&CorrDistInfo::conditional_best_response_policies);

m.def("cce_dist",
py::overload_cast<const Game&, const CorrelationDevice&, int, float>(
&open_spiel::algorithms::CCEDist),
"Returns a player's distance to a coarse-correlated equilibrium.",
py::arg("game"),
py::arg("correlation_device"),
py::arg("player"),
py::arg("prob_cut_threshold") = -1.0);
m.def(
"cce_dist",
[](std::shared_ptr<const Game> game,
const CorrelationDevice& correlation_device, int player,
float prob_cut_threshold) {
return algorithms::CCEDist(*game, correlation_device, player,
prob_cut_threshold);
},
"Returns a player's distance to a coarse-correlated equilibrium.",
py::arg("game"), py::arg("correlation_device"), py::arg("player"),
py::arg("prob_cut_threshold") = -1.0);

m.def("cce_dist",
py::overload_cast<const Game&, const CorrelationDevice&, float>(
&open_spiel::algorithms::CCEDist),
"Returns the distance to a coarse-correlated equilibrium.",
py::arg("game"),
py::arg("correlation_device"),
py::arg("prob_cut_threshold") = -1.0);
m.def(
"cce_dist",
[](std::shared_ptr<const Game> game,
const CorrelationDevice& correlation_device,
float prob_cut_threshold) {
return algorithms::CCEDist(*game, correlation_device,
prob_cut_threshold);
},
"Returns the distance to a coarse-correlated equilibrium.",
py::arg("game"), py::arg("correlation_device"),
py::arg("prob_cut_threshold") = -1.0);

m.def("ce_dist",
py::overload_cast<const Game&, const CorrelationDevice&>(
&open_spiel::algorithms::CEDist),
"Returns the distance to a correlated equilibrium.");
m.def(
"ce_dist",
[](std::shared_ptr<const Game> game,
const CorrelationDevice& correlation_device) {
return algorithms::CEDist(*game, correlation_device);
},
"Returns the distance to a correlated equilibrium.");

// TODO(author5): expose the rest of the functions.
}
Expand Down
27 changes: 19 additions & 8 deletions open_spiel/python/pybind11/algorithms_trajectories.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,28 @@ void init_pyspiel_algorithms_trajectories(py::module& m) {
.def("resize_fields",
&open_spiel::algorithms::BatchedTrajectory::ResizeFields);

m.def("record_batched_trajectories",
py::overload_cast<
const Game&, const std::vector<open_spiel::TabularPolicy>&,
const std::unordered_map<std::string, int>&, int, bool, int, int>(
&open_spiel::algorithms::RecordBatchedTrajectory),
"Records a batch of trajectories.");
m.def(
"record_batched_trajectories",
[](std::shared_ptr<const Game> game,
const std::vector<TabularPolicy>& policies,
const std::unordered_map<std::string, int>& state_to_index,
int batch_size, bool include_full_observations, int seed,
int max_unroll_length) {
return open_spiel::algorithms::RecordBatchedTrajectory(
*game, policies, state_to_index, batch_size,
include_full_observations, seed, max_unroll_length);
},
"Records a batch of trajectories.");

py::class_<open_spiel::algorithms::TrajectoryRecorder>(m,
"TrajectoryRecorder")
.def(py::init<const Game&, const std::unordered_map<std::string, int>&,
int>())
.def(py::init(
[](std::shared_ptr<const Game> game,
const std::unordered_map<std::string, int>& state_to_index,
int seed) {
return new algorithms::TrajectoryRecorder(*game, state_to_index,
seed);
}))
.def("record_batch",
&open_spiel::algorithms::TrajectoryRecorder::RecordBatch);
}
Expand Down
51 changes: 33 additions & 18 deletions open_spiel/python/pybind11/bots.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,20 @@ void init_pyspiel_bots(py::module& m) {
"Returns a list of registered bot names.");
m.def(
"bots_that_can_play_game",
py::overload_cast<const Game&, Player>(&open_spiel::BotsThatCanPlayGame),
[](std::shared_ptr<const Game> game, int player) {
return BotsThatCanPlayGame(*game, player);
},
py::arg("game"), py::arg("player"),
"Returns a list of bot names that can play specified game for the "
"given player.");
m.def("bots_that_can_play_game",
py::overload_cast<const Game&>(&open_spiel::BotsThatCanPlayGame),
py::arg("game"),
"Returns a list of bot names that can play specified game for any "
"player.");
m.def(
"bots_that_can_play_game",
[](std::shared_ptr<const Game> game) {
return BotsThatCanPlayGame(*game);
},
py::arg("game"),
"Returns a list of bot names that can play specified game for any "
"player.");

py::class_<algorithms::Evaluator,
std::shared_ptr<algorithms::Evaluator>> mcts_evaluator(
Expand Down Expand Up @@ -223,14 +228,21 @@ void init_pyspiel_bots(py::module& m) {
.def("children_str", &SearchNode::ChildrenStr);

py::class_<algorithms::MCTSBot, Bot>(m, "MCTSBot")
.def(py::init<const Game&, std::shared_ptr<Evaluator>, double, int,
int64_t, bool, int, bool,
::open_spiel::algorithms::ChildSelectionPolicy>(),
py::arg("game"), py::arg("evaluator"), py::arg("uct_c"),
py::arg("max_simulations"), py::arg("max_memory_mb"),
py::arg("solve"), py::arg("seed"), py::arg("verbose"),
py::arg("child_selection_policy") =
algorithms::ChildSelectionPolicy::UCT)
.def(
py::init([](std::shared_ptr<const Game> game,
std::shared_ptr<Evaluator> evaluator, double uct_c,
int max_simulations, int64_t max_memory_mb, bool solve,
int seed, bool verbose,
algorithms::ChildSelectionPolicy child_selection_policy) {
return new algorithms::MCTSBot(
*game, evaluator, uct_c, max_simulations, max_memory_mb, solve,
seed, verbose, child_selection_policy);
}),
py::arg("game"), py::arg("evaluator"), py::arg("uct_c"),
py::arg("max_simulations"), py::arg("max_memory_mb"),
py::arg("solve"), py::arg("seed"), py::arg("verbose"),
py::arg("child_selection_policy") =
algorithms::ChildSelectionPolicy::UCT)
.def("step", &algorithms::MCTSBot::Step)
.def("mcts_search", &algorithms::MCTSBot::MCTSearch);

Expand Down Expand Up @@ -270,10 +282,13 @@ void init_pyspiel_bots(py::module& m) {

m.def("make_stateful_random_bot", open_spiel::MakeStatefulRandomBot,
"A stateful random bot, for test purposes.");
m.def("make_policy_bot",
py::overload_cast<const Game&, Player, int, std::shared_ptr<Policy>>(
open_spiel::MakePolicyBot),
"A bot that samples from a policy.");
m.def(
"make_policy_bot",
[](std::shared_ptr<const Game> game, Player player_id, int seed,
std::shared_ptr<Policy> policy) {
return MakePolicyBot(*game, player_id, seed, policy);
},
"A bot that samples from a policy.");

#if OPEN_SPIEL_BUILD_WITH_ROSHAMBO
m.attr("ROSHAMBO_NUM_THROWS") = py::int_(open_spiel::roshambo::kNumThrows);
Expand Down
29 changes: 16 additions & 13 deletions open_spiel/python/pybind11/game_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,35 @@ namespace py = ::pybind11;

void init_pyspiel_game_transforms(py::module& m) {
m.def("load_game_as_turn_based",
py::overload_cast<const std::string&>(&open_spiel::LoadGameAsTurnBased),
py::overload_cast<const std::string&>(&LoadGameAsTurnBased),
"Converts a simultaneous game into an turn-based game with infosets.");

m.def("load_game_as_turn_based",
py::overload_cast<const std::string&, const GameParameters&>(
&open_spiel::LoadGameAsTurnBased),
&LoadGameAsTurnBased),
"Converts a simultaneous game into an turn-based game with infosets.");

m.def("extensive_to_tensor_game", open_spiel::ExtensiveToTensorGame,
m.def("extensive_to_tensor_game", ExtensiveToTensorGame,
"Converts an extensive-game to its equivalent tensor game, "
"which is exponentially larger. Use only with small games.");

m.def("convert_to_turn_based",
[](const std::shared_ptr<open_spiel::Game>& game) {
return open_spiel::ConvertToTurnBased(*game);
},
"Returns a turn-based version of the given game.");
m.def(
"convert_to_turn_based",
[](std::shared_ptr<const Game> game) {
return ConvertToTurnBased(*game);
},
"Returns a turn-based version of the given game.");

m.def("create_repeated_game",
py::overload_cast<const Game&, const GameParameters&>(
&open_spiel::CreateRepeatedGame),
"Creates a repeated game from a stage game.");
m.def(
"create_repeated_game",
[](std::shared_ptr<const Game> game, const GameParameters& params) {
return CreateRepeatedGame(*game, params);
},
"Creates a repeated game from a stage game.");

m.def("create_repeated_game",
py::overload_cast<const std::string&, const GameParameters&>(
&open_spiel::CreateRepeatedGame),
&CreateRepeatedGame),
"Creates a repeated game from a stage game.");
}
} // namespace open_spiel
7 changes: 5 additions & 2 deletions open_spiel/python/pybind11/observer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ void init_pyspiel_observer(py::module& m) {
// C++ Observation, intended only for the Python Observation class, not
// for general Python code.
py::class_<Observation>(m, "_Observation", py::buffer_protocol())
.def(py::init<const Game&, std::shared_ptr<Observer>>(), py::arg("game"),
py::arg("observer"))
.def(py::init([](std::shared_ptr<const Game> game,
std::shared_ptr<Observer> observer) {
return new Observation(*game, observer);
}),
py::arg("game"), py::arg("observer"))
.def("tensors", &Observation::tensors)
.def("tensors_info", &Observation::tensors_info)
.def("string_from", &Observation::StringFrom)
Expand Down
86 changes: 55 additions & 31 deletions open_spiel/python/pybind11/policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ void init_pyspiel_policy(py::module& m) {
&open_spiel::PreferredActionPolicy::GetStatePolicy);

py::class_<open_spiel::algorithms::CFRSolver>(m, "CFRSolver")
.def(py::init<const Game&>())
.def(py::init([](std::shared_ptr<const Game> game) {
return new algorithms::CFRSolver(*game);
}))
.def("evaluate_and_update_policy",
&open_spiel::algorithms::CFRSolver::EvaluateAndUpdatePolicy)
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
Expand All @@ -147,7 +149,9 @@ void init_pyspiel_policy(py::module& m) {
}));

py::class_<open_spiel::algorithms::CFRPlusSolver>(m, "CFRPlusSolver")
.def(py::init<const Game&>())
.def(py::init([](std::shared_ptr<const Game> game) {
return new algorithms::CFRPlusSolver(*game);
}))
.def("evaluate_and_update_policy",
&open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
Expand All @@ -163,7 +167,9 @@ void init_pyspiel_policy(py::module& m) {
}));

py::class_<open_spiel::algorithms::CFRBRSolver>(m, "CFRBRSolver")
.def(py::init<const Game&>())
.def(py::init([](std::shared_ptr<const Game> game) {
return new algorithms::CFRBRSolver(*game);
}))
.def("evaluate_and_update_policy",
&open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
Expand All @@ -184,7 +190,11 @@ void init_pyspiel_policy(py::module& m) {

py::class_<open_spiel::algorithms::ExternalSamplingMCCFRSolver>(
m, "ExternalSamplingMCCFRSolver")
.def(py::init<const Game&, int, open_spiel::algorithms::AverageType>(),
.def(py::init([](std::shared_ptr<const Game> game, int seed,
algorithms::AverageType average_type) {
return new algorithms::ExternalSamplingMCCFRSolver(*game, seed,
average_type);
}),
py::arg("game"), py::arg("seed") = 0,
py::arg("avg_type") = open_spiel::algorithms::AverageType::kSimple)
.def("run_iteration",
Expand All @@ -204,7 +214,12 @@ void init_pyspiel_policy(py::module& m) {

py::class_<open_spiel::algorithms::OutcomeSamplingMCCFRSolver>(
m, "OutcomeSamplingMCCFRSolver")
.def(py::init<const Game&, double, int>(), py::arg("game"),
.def(py::init(
[](std::shared_ptr<const Game> game, double epsilon, int seed) {
return new algorithms::OutcomeSamplingMCCFRSolver(
*game, epsilon, seed);
}),
py::arg("game"),
py::arg("epsilon") = open_spiel::algorithms::
OutcomeSamplingMCCFRSolver::kDefaultEpsilon,
py::arg("seed") = -1)
Expand Down Expand Up @@ -267,45 +282,54 @@ void init_pyspiel_policy(py::module& m) {
py::arg("use_infostate_get_policy"),
py::arg("prob_cut_threshold") = 0.0);

m.def("exploitability",
py::overload_cast<const Game&, const Policy&>(&Exploitability),
"Returns the sum of the utility that a best responder wins when when "
"playing against 1) the player 0 policy contained in `policy` and 2) "
"the player 1 policy contained in `policy`."
"This only works for two player, zero- or constant-sum sequential "
"games, and raises a SpielFatalError if an incompatible game is passed "
"to it.");
m.def(
"exploitability",
[](std::shared_ptr<const Game> game, const Policy& policy) {
return Exploitability(*game, policy);
},
"Returns the sum of the utility that a best responder wins when when "
"playing against 1) the player 0 policy contained in `policy` and 2) "
"the player 1 policy contained in `policy`."
"This only works for two player, zero- or constant-sum sequential "
"games, and raises a SpielFatalError if an incompatible game is passed "
"to it.");

m.def(
"exploitability",
py::overload_cast<
const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
&Exploitability),
[](std::shared_ptr<const Game> game,
const std::unordered_map<std::string, ActionsAndProbs>& policy) {
return Exploitability(*game, policy);
},
"Returns the sum of the utility that a best responder wins when when "
"playing against 1) the player 0 policy contained in `policy` and 2) "
"the player 1 policy contained in `policy`."
"This only works for two player, zero- or constant-sum sequential "
"games, and raises a SpielFatalError if an incompatible game is passed "
"to it.");

m.def("nash_conv",
py::overload_cast<const Game&, const Policy&, bool>(&NashConv),
"Calculates a measure of how far the given policy is from a Nash "
"equilibrium by returning the sum of the improvements in the value "
"that each player could obtain by unilaterally changing their strategy "
"while the opposing player maintains their current strategy (which "
"for a Nash equilibrium, this value is 0). The third parameter is to "
"indicate whether to use the Policy::GetStatePolicy(const State&) "
"instead of Policy::GetStatePolicy(const std::string& info_state) for "
"computation of the on-policy expected values.",
py::arg("game"), py::arg("policy"),
py::arg("use_state_get_policy") = false);
m.def(
"nash_conv",
[](std::shared_ptr<const Game> game, const Policy& policy,
bool use_state_get_policy) {
return NashConv(*game, policy, use_state_get_policy);
},
"Calculates a measure of how far the given policy is from a Nash "
"equilibrium by returning the sum of the improvements in the value "
"that each player could obtain by unilaterally changing their strategy "
"while the opposing player maintains their current strategy (which "
"for a Nash equilibrium, this value is 0). The third parameter is to "
"indicate whether to use the Policy::GetStatePolicy(const State&) "
"instead of Policy::GetStatePolicy(const std::string& info_state) for "
"computation of the on-policy expected values.",
py::arg("game"), py::arg("policy"),
py::arg("use_state_get_policy") = false);

m.def(
"nash_conv",
py::overload_cast<
const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
&NashConv),
[](std::shared_ptr<const Game> game,
const std::unordered_map<std::string, ActionsAndProbs>& policy) {
return NashConv(*game, policy);
},
"Calculates a measure of how far the given policy is from a Nash "
"equilibrium by returning the sum of the improvements in the value "
"that each player could obtain by unilaterally changing their strategy "
Expand Down
Loading

0 comments on commit c9f2e37

Please sign in to comment.