Skip to content

Commit

Permalink
allow backends to suggest minibatch size (#1877)
Browse files Browse the repository at this point in the history
* allow backends to suggest minibatch size
* simple cuda heuristic
  • Loading branch information
borg323 authored Nov 14, 2023
1 parent a2f98f7 commit 8539794
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 28 deletions.
4 changes: 2 additions & 2 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ cache:
- C:\ndk\android-ndk-r19c\toolchains\llvm\prebuilt\windows-x86_64
before_build:
- cmd: git submodule update --init --recursive
- cmd: IF %BLAS%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MINIBATCH_SIZE 7 & echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
- cmd: IF %BLAS%==true (echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
- cmd: IF %ANDROID%==true (echo.#define DEFAULT_MAX_PREFETCH 0 & echo.#define DEFAULT_TASK_WORKERS 0) > params_override.h
- cmd: SET BUILD_BLAS=%BLAS%
- cmd: IF %OPENCL%==true SET BUILD_BLAS=true
- cmd: IF %DX%==true SET BUILD_BLAS=true
Expand Down
18 changes: 7 additions & 11 deletions src/mcts/params.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018-2019 The LCZero Authors
Copyright (C) 2018-2023 The LCZero Authors
Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -38,9 +38,6 @@
#include "params_override.h"
#endif

#ifndef DEFAULT_MINIBATCH_SIZE
#define DEFAULT_MINIBATCH_SIZE 256
#endif
#ifndef DEFAULT_MAX_PREFETCH
#define DEFAULT_MAX_PREFETCH 32
#endif
Expand Down Expand Up @@ -156,7 +153,7 @@ const OptionId SearchParams::kMiniBatchSizeId{
"minibatch-size", "MinibatchSize",
"How many positions the engine tries to batch together for parallel NN "
"computation. Larger batches may reduce strength a bit, especially with a "
"small number of playouts."};
"small number of playouts. Set to 0 to use a backend suggested value."};
const OptionId SearchParams::kMaxPrefetchBatchId{
"max-prefetch", "MaxPrefetch",
"When the engine cannot gather a large enough batch for immediate use, try "
Expand Down Expand Up @@ -287,7 +284,7 @@ const OptionId SearchParams::kOutOfOrderEvalId{
"in the cache or is terminal, evaluate it right away without sending the "
"batch to the NN. When off, this may only happen with the very first node "
"of a batch; when on, this can happen with any node."};
const OptionId SearchParams::kMaxOutOfOrderEvalsId{
const OptionId SearchParams::kMaxOutOfOrderEvalsFactorId{
"max-out-of-order-evals-factor", "MaxOutOfOrderEvalsFactor",
"Maximum number of out of order evals during gathering of a batch is "
"calculated by multiplying the maximum batch size by this number."};
Expand Down Expand Up @@ -459,7 +456,7 @@ const OptionId SearchParams::kSearchSpinBackoffId{
void SearchParams::Populate(OptionsParser* options) {
// Here the "uci optimized defaults" are set.
// Many of them are overridden with training specific values in tournament.cc.
options->Add<IntOption>(kMiniBatchSizeId, 1, 1024) = DEFAULT_MINIBATCH_SIZE;
options->Add<IntOption>(kMiniBatchSizeId, 0, 1024) = 0;
options->Add<IntOption>(kMaxPrefetchBatchId, 0, 1024) = DEFAULT_MAX_PREFETCH;
options->Add<FloatOption>(kCpuctId, 0.0f, 100.0f) = 1.745f;
options->Add<FloatOption>(kCpuctAtRootId, 0.0f, 100.0f) = 1.745f;
Expand Down Expand Up @@ -497,7 +494,7 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<FloatOption>(kMaxCollisionVisitsScalingPowerId, 0.01, 100) =
1.25;
options->Add<BoolOption>(kOutOfOrderEvalId) = true;
options->Add<FloatOption>(kMaxOutOfOrderEvalsId, 0.0f, 100.0f) = 2.4f;
options->Add<FloatOption>(kMaxOutOfOrderEvalsFactorId, 0.0f, 100.0f) = 2.4f;
options->Add<BoolOption>(kStickyEndgamesId) = true;
options->Add<BoolOption>(kSyzygyFastPlayId) = false;
options->Add<IntOption>(kMultiPvId, 1, 500) = 1;
Expand Down Expand Up @@ -637,9 +634,8 @@ SearchParams::SearchParams(const OptionsDict& options)
options.Get<float>(kContemptMaxValueId),
options.Get<float>(kWDLContemptAttenuationId))),
kWDLEvalObjectivity(options.Get<float>(kWDLEvalObjectivityId)),
kMaxOutOfOrderEvals(std::max(
1, static_cast<int>(options.Get<float>(kMaxOutOfOrderEvalsId) *
options.Get<int>(kMiniBatchSizeId)))),
kMaxOutOfOrderEvalsFactor(
options.Get<float>(kMaxOutOfOrderEvalsFactorId)),
kNpsLimit(options.Get<float>(kNpsLimitId)),
kSolidTreeThreshold(options.Get<int>(kSolidTreeThresholdId)),
kTaskWorkersPerSearchWorker(
Expand Down
8 changes: 5 additions & 3 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ class SearchParams {
float GetWDLRescaleRatio() const { return kWDLRescaleParams.ratio; }
float GetWDLRescaleDiff() const { return kWDLRescaleParams.diff; }
float GetWDLEvalObjectivity() const { return kWDLEvalObjectivity; }
int GetMaxOutOfOrderEvals() const { return kMaxOutOfOrderEvals; }
// Returns the MaxOutOfOrderEvalsFactor option value. The maximum number of
// out-of-order evals per batch is this factor multiplied by the batch size.
float GetMaxOutOfOrderEvalsFactor() const {
return kMaxOutOfOrderEvalsFactor;
}
float GetNpsLimit() const { return kNpsLimit; }
int GetSolidTreeThreshold() const { return kSolidTreeThreshold; }

Expand Down Expand Up @@ -215,7 +217,7 @@ class SearchParams {
static const OptionId kWDLDrawRateTargetId;
static const OptionId kWDLDrawRateReferenceId;
static const OptionId kWDLBookExitBiasId;
static const OptionId kMaxOutOfOrderEvalsId;
static const OptionId kMaxOutOfOrderEvalsFactorId;
static const OptionId kNpsLimitId;
static const OptionId kSolidTreeThresholdId;
static const OptionId kTaskWorkersPerSearchWorkerId;
Expand Down Expand Up @@ -274,7 +276,7 @@ class SearchParams {
const float kContempt;
const WDLRescaleParams kWDLRescaleParams;
const float kWDLEvalObjectivity;
const int kMaxOutOfOrderEvals;
const float kMaxOutOfOrderEvalsFactor;
const float kNpsLimit;
const int kSolidTreeThreshold;
const int kTaskWorkersPerSearchWorker;
Expand Down
14 changes: 7 additions & 7 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018-2019 The LCZero Authors
Copyright (C) 2018-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -1246,9 +1246,9 @@ void SearchWorker::InitializeIteration(
std::unique_ptr<NetworkComputation> computation) {
computation_ = std::make_unique<CachingComputation>(std::move(computation),
search_->cache_);
computation_->Reserve(params_.GetMiniBatchSize());
computation_->Reserve(target_minibatch_size_);
minibatch_.clear();
minibatch_.reserve(2 * params_.GetMiniBatchSize());
minibatch_.reserve(2 * target_minibatch_size_);
}

// 2. Gather minibatch.
Expand Down Expand Up @@ -1299,8 +1299,8 @@ void SearchWorker::GatherMinibatch() {
// Gather nodes to process in the current batch.
// If we had too many nodes out of order, also interrupt the iteration so
// that search can exit.
while (minibatch_size < params_.GetMiniBatchSize() &&
number_out_of_order_ < params_.GetMaxOutOfOrderEvals()) {
while (minibatch_size < target_minibatch_size_ &&
number_out_of_order_ < max_out_of_order_) {
// If there's something to process without touching slow neural net, do it.
if (minibatch_size > 0 && computation_->GetCacheMisses() == 0) return;

Expand All @@ -1322,8 +1322,8 @@ void SearchWorker::GatherMinibatch() {
int new_start = static_cast<int>(minibatch_.size());

PickNodesToExtend(
std::min({collisions_left, params_.GetMiniBatchSize() - minibatch_size,
params_.GetMaxOutOfOrderEvals() - number_out_of_order_}));
std::min({collisions_left, target_minibatch_size_ - minibatch_size,
max_out_of_order_ - number_out_of_order_}));

// Count the non-collisions.
int non_collisions = 0;
Expand Down
11 changes: 10 additions & 1 deletion src/mcts/search.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018 The LCZero Authors
Copyright (C) 2018-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -222,6 +222,13 @@ class SearchWorker {
this->RunTasks(i);
});
}
target_minibatch_size_ = params_.GetMiniBatchSize();
if (target_minibatch_size_ == 0) {
target_minibatch_size_ = search_->network_->GetMiniBatchSize();
}
max_out_of_order_ =
std::max(1, static_cast<int>(params_.GetMaxOutOfOrderEvalsFactor() *
target_minibatch_size_));
}

~SearchWorker() {
Expand Down Expand Up @@ -452,6 +459,8 @@ class SearchWorker {
// List of nodes to process.
std::vector<NodeToProcess> minibatch_;
std::unique_ptr<CachingComputation> computation_;
// Batch size this worker gathers towards: the MinibatchSize option, or the
// value suggested by the network backend when that option is 0.
int target_minibatch_size_;
// Cap on out-of-order evals while gathering one batch: the
// MaxOutOfOrderEvalsFactor times target_minibatch_size_, but at least 1.
int max_out_of_order_;
// History is reset and extended by PickNodeToExtend().
PositionHistory history_;
int number_out_of_order_ = 0;
Expand Down
6 changes: 5 additions & 1 deletion src/neural/blas/network_blas.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018-2022 The LCZero Authors
Copyright (C) 2018-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -167,6 +167,10 @@ class BlasNetwork : public Network {
return capabilities_;
}

// Minibatch size suggested by the BLAS backend.
int GetMiniBatchSize() const override {
  constexpr int kSuggestedMiniBatchSize = 7;
  return kSuggestedMiniBatchSize;
}

void InitThread(int id) override { Numa::BindThread(id); }

std::unique_ptr<Buffers> GetBuffers() {
Expand Down
7 changes: 7 additions & 0 deletions src/neural/cuda/network_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class CudaNetwork : public Network {
showDeviceInfo(deviceProp);

l2_cache_size_ = deviceProp.l2CacheSize;
sm_count_ = deviceProp.multiProcessorCount;

allow_cache_opt_ = options.GetOrDefault<bool>("cache_opt", false);

Expand Down Expand Up @@ -895,6 +896,11 @@ class CudaNetwork : public Network {
return capabilities_;
}

// Minibatch size suggested by the CUDA backend.
int GetMiniBatchSize() const override {
  // Simple heuristic that seems to work for a wide range of GPUs: scale the
  // batch with the multiprocessor count reported for the device.
  constexpr int kBatchPerMultiprocessor = 2;
  return kBatchPerMultiprocessor * sm_count_;
}

std::unique_ptr<NetworkComputation> NewComputation() override {
// Set correct gpu id for this computation (as it might have been called
// from a different thread).
Expand Down Expand Up @@ -931,6 +937,7 @@ class CudaNetwork : public Network {
const NetworkCapabilities capabilities_;
int gpu_id_;
int l2_cache_size_;
int sm_count_;
int max_batch_size_;
bool wdl_;
bool moves_left_;
Expand Down
3 changes: 2 additions & 1 deletion src/neural/network.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018 The LCZero Authors
Copyright (C) 2018-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -107,6 +107,7 @@ class Network {
virtual const NetworkCapabilities& GetCapabilities() const = 0;
virtual std::unique_ptr<NetworkComputation> NewComputation() = 0;
virtual void InitThread(int /*id*/) {}
// Minibatch size suggested by the backend; search uses it when the
// MinibatchSize option is 0. Backends may override; the default is 256.
virtual int GetMiniBatchSize() const { return 256; }
virtual ~Network() = default;
};

Expand Down
6 changes: 5 additions & 1 deletion src/neural/onednn/network_onednn.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2021-2022 The LCZero Authors
Copyright (C) 2021-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -801,6 +801,10 @@ class OnednnNetwork : public Network {
return capabilities_;
}

// Minibatch size suggested by the oneDNN backend: batch_size_ times steps_.
int GetMiniBatchSize() const override {
  const auto suggested = batch_size_ * steps_;
  return suggested;
}

std::unique_ptr<NetworkComputation> NewComputation() override {
return std::make_unique<OnednnNetworkComputation>(this, wdl_, moves_left_);
}
Expand Down
6 changes: 5 additions & 1 deletion src/neural/onnx/network_onnx.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2021 The LCZero Authors
Copyright (C) 2021-2023 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -94,6 +94,10 @@ class OnnxNetwork : public Network {
const NetworkCapabilities& GetCapabilities() const override {
return capabilities_;
}
// Minibatch size suggested by the ONNX backend.
int GetMiniBatchSize() const override {
  // A batch_size_ of -1 falls back to the generic Network default.
  if (batch_size_ == -1) return Network::GetMiniBatchSize();
  return batch_size_ * steps_;
}

Ort::Env onnx_env_;
// Prepare sessions for this many multiples of the batch size;
Expand Down

0 comments on commit 8539794

Please sign in to comment.