Skip to content

Commit

Permalink
Increase use of workspace. (LeelaChessZero#1498)
Browse files Browse the repository at this point in the history
3% increase in nps on benchmark.
100% increase in nps in #1 positions.
  • Loading branch information
Tilps authored Jan 30, 2021
1 parent e93a2a8 commit d2e03fd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class PositionHistory {
public:
PositionHistory() = default;
PositionHistory(const PositionHistory& other) = default;
PositionHistory(PositionHistory&& other) = default;

PositionHistory& operator=(const PositionHistory& other) = default;
PositionHistory& operator=(PositionHistory&& other) = default;

// Returns first position of the game (or fen from which it was initialized).
const Position& Starting() const { return positions_.front(); }
Expand Down
22 changes: 11 additions & 11 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1345,11 +1345,11 @@ void SearchWorker::GatherMinibatch2() {
}

void SearchWorker::ProcessPickedTask(int start_idx, int end_idx,
TaskWorkspace*) {
TaskWorkspace* workspace) {
auto& history = workspace->history;
// This code runs multiple passes of work across the same input in order to
// reduce taking/dropping mutexes in quick succession.
PositionHistory history = search_->played_history_;
history.Reserve(search_->played_history_.GetLength() + 30);
history = search_->played_history_;

// First pass - Extend nodes.
for (int i = start_idx; i < end_idx; i++) {
Expand Down Expand Up @@ -1491,15 +1491,15 @@ void SearchWorker::PickNodesToExtendTask(Node* node, int base_depth,
// with tasks.
// TODO: pre-reserve visits_to_perform for expected depth and likely maximum
// width. Maybe even do so outside of lock scope.
std::vector<std::unique_ptr<std::array<int, 256>>> visits_to_perform;
auto& vtp_buffer = workspace->vtp_buffer;
visits_to_perform.reserve(30);
std::vector<int> vtp_last_filled;
vtp_last_filled.reserve(30);
std::vector<int> current_path;
current_path.reserve(30);
std::vector<Move> moves_to_path = moves_to_base;
moves_to_path.reserve(30);
auto& visits_to_perform = workspace->visits_to_perform;
visits_to_perform.clear();
auto& vtp_last_filled = workspace->vtp_last_filled;
vtp_last_filled.clear();
auto& current_path = workspace->current_path;
current_path.clear();
auto& moves_to_path = workspace->moves_to_path;
moves_to_path = moves_to_base;
// Sometimes receiver is reused, othertimes not, so only jump start if small.
if (receiver->capacity() < 30) {
receiver->reserve(receiver->size() + 30);
Expand Down
13 changes: 13 additions & 0 deletions src/mcts/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,19 @@ class SearchWorker {
struct TaskWorkspace {
std::array<Node::Iterator, 256> cur_iters;
std::vector<std::unique_ptr<std::array<int, 256>>> vtp_buffer;
std::vector<std::unique_ptr<std::array<int, 256>>> visits_to_perform;
std::vector<int> vtp_last_filled;
std::vector<int> current_path;
std::vector<Move> moves_to_path;
PositionHistory history;
TaskWorkspace() {
vtp_buffer.reserve(30);
visits_to_perform.reserve(30);
vtp_last_filled.reserve(30);
current_path.reserve(30);
moves_to_path.reserve(30);
history.Reserve(30);
}
};

struct PickTask {
Expand Down

0 comments on commit d2e03fd

Please sign in to comment.