Skip to content

Commit

Permalink
Expose CorrDevBuilder to Python and add example use of the CorrDist f…
Browse files Browse the repository at this point in the history
…unctions from python.

PiperOrigin-RevId: 572869884
Change-Id: Id5075edd8a234e4a838926152d22d4d0248f3558
  • Loading branch information
lanctot committed Oct 12, 2023
1 parent a44607a commit 8859b79
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
5 changes: 5 additions & 0 deletions open_spiel/algorithms/cfr.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ class CFRSolverBase {
return std::make_shared<CFRCurrentPolicy>(info_states_, nullptr);
}

// Returns the solver's current policy converted to a TabularPolicy.
//
// A CFRCurrentPolicy is constructed over info_states_ (its second constructor
// argument is passed as nullptr) and immediately converted with AsTabular().
TabularPolicy TabularCurrentPolicy() const {
  return CFRCurrentPolicy(info_states_, nullptr).AsTabular();
}

// Returns a mutable reference to the solver's info-state values table.
CFRInfoStateValuesTable& InfoStateValuesTable() {
  return info_states_;
}

// See comments above CFRInfoStateValues::Serialize(double_precision) for
Expand Down
53 changes: 53 additions & 0 deletions open_spiel/python/algorithms/cfr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,58 @@ def test_cpp_algorithms_identical_to_python_algorithm(self, game, cpp_class,
self.assertEqual(cpp_expl, python_expl)


class CorrDistTest(absltest.TestCase):
  """Exercises the C++ correlation-device distance functions from Python.

  These distances are the analogues of NashConv for various forms of
  correlated equilibria.
  """

  def _cce_dist_after_each_iteration(self, game, solver, num_iterations,
                                     current_policy_fn):
    """Runs the solver and tracks the CCE distance of its policy history.

    After every solver iteration the current policy is appended to a growing
    list, a uniform correlation device is built over that list, and the
    device's CCE distance is recorded.

    Args:
      game: the game passed to `pyspiel.cce_dist`.
      solver: an object exposing `evaluate_and_update_policy()`.
      num_iterations: how many solver iterations to run.
      current_policy_fn: callable mapping the solver to the policy object that
        is stored for the iteration.

    Returns:
      The list of `dist_value`s, one per iteration, in iteration order.
    """
    policy_history = []
    distances = []
    for _ in range(num_iterations):
      solver.evaluate_and_update_policy()
      policy_history.append(current_policy_fn(solver))
      device = pyspiel.uniform_correlation_device(policy_history)
      distances.append(pyspiel.cce_dist(game, device).dist_value)
    return distances

  def test_cce_dist_kuhn_3p_cpp(self):
    """CCE distance shrinks over 10 C++ CFR iterations on 3-player Kuhn."""
    kuhn = pyspiel.load_game("kuhn_poker(players=3)")
    cpp_solver = pyspiel.CFRSolver(kuhn)  # C++ solver
    distances = self._cce_dist_after_each_iteration(
        kuhn, cpp_solver, 10, lambda s: s.tabular_current_policy())
    # The last recorded distance must be strictly below the first one.
    self.assertLess(distances[-1], distances[0])

  def test_cce_dist_kuhn_3p(self):
    """CCE distance shrinks over 10 Python CFR iterations on 3-player Kuhn."""
    kuhn = pyspiel.load_game("kuhn_poker(players=3)")
    py_solver = cfr._CFRSolver(kuhn,
                               regret_matching_plus=False,
                               linear_averaging=False,
                               alternating_updates=True)

    def as_pyspiel_policy(solver):
      # The CorrDist functions live on the C++ side and need a
      # pyspiel.TabularPolicy, so the Python policy is converted first.
      return policy.python_policy_to_pyspiel_policy(solver.current_policy())

    distances = self._cce_dist_after_each_iteration(
        kuhn, py_solver, 10, as_pyspiel_policy)
    self.assertLess(distances[-1], distances[0])

  def test_cce_dist_sheriff_cpp(self):
    """CCE distance shrinks over 3 C++ CFR iterations on Sheriff."""
    sheriff = pyspiel.load_game("sheriff")
    cpp_solver = pyspiel.CFRSolver(sheriff)  # C++ solver
    distances = self._cce_dist_after_each_iteration(
        sheriff, cpp_solver, 3, lambda s: s.tabular_current_policy())
    self.assertLess(distances[-1], distances[0])


if __name__ == "__main__":
absltest.main()
21 changes: 19 additions & 2 deletions open_spiel/python/pybind11/algorithms_corr_dist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@

#include "open_spiel/python/pybind11/algorithms_corr_dist.h"

// Python bindings for corr_dist.h and corr_dev_builder.h
#include <memory>

#include "open_spiel/algorithms/corr_dev_builder.h"
#include "open_spiel/algorithms/corr_dist.h"
#include "open_spiel/python/pybind11/pybind11.h"
#include "open_spiel/spiel.h"
#include "pybind11/include/pybind11/cast.h"
#include "pybind11/include/pybind11/pybind11.h"

namespace open_spiel {
namespace py = ::pybind11;

using open_spiel::algorithms::CorrDevBuilder;
using open_spiel::algorithms::CorrDistInfo;
using open_spiel::algorithms::CorrelationDevice;

Expand Down Expand Up @@ -50,6 +53,20 @@ void init_pyspiel_algorithms_corr_dist(py::module& m) {
.def_readonly("conditional_best_response_policies",
&CorrDistInfo::conditional_best_response_policies);

py::class_<CorrDevBuilder> corr_dev_builder(m, "CorrDevBuilder");
corr_dev_builder.def(py::init<int>(), py::arg("seed") = 0)
.def("add_deterministic_joint_policy",
&CorrDevBuilder::AddDeterminsticJointPolicy,
py::arg("policy"), py::arg("weight") = 1.0)
.def("add_sampled_joint_policy",
&CorrDevBuilder::AddSampledJointPolicy,
py::arg("policy"), py::arg("num_samples"), py::arg("weight") = 1.0)
.def("add_mixed_joint_policy",
&CorrDevBuilder::AddMixedJointPolicy,
py::arg("policy"),
py::arg("weight") = 1.0)
.def("get_correlation_device", &CorrDevBuilder::GetCorrelationDevice);

m.def(
"cce_dist",
[](std::shared_ptr<const Game> game,
Expand Down
2 changes: 2 additions & 0 deletions open_spiel/python/pybind11/policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ void init_pyspiel_policy(py::module& m) {
.def("average_policy", &open_spiel::algorithms::CFRSolver::AveragePolicy)
.def("tabular_average_policy",
&open_spiel::algorithms::CFRSolver::TabularAveragePolicy)
.def("tabular_current_policy",
&open_spiel::algorithms::CFRSolver::TabularCurrentPolicy)
.def(py::pickle(
[](const open_spiel::algorithms::CFRSolver& solver) { // __getstate__
return solver.Serialize();
Expand Down

0 comments on commit 8859b79

Please sign in to comment.