Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Attempt to optimize the position skipping to fit a distribution. #199

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions lib/nnue_training_data_formats.h
Original file line number Diff line number Diff line change
Expand Up @@ -7560,12 +7560,14 @@ namespace binpack
int concurrency,
std::string path,
std::ios_base::openmode om = std::ios_base::app,
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter = nullptr
) :
m_concurrency(concurrency),
m_inputFile(path, om),
m_bufferOffset(0),
m_skipPredicate(std::move(skipPredicate))
m_skipPredicate(std::move(skipPredicate)),
m_bufferFilter(std::move(bufferFilter))
{
m_numRunningWorkers.store(0);
if (!m_inputFile.hasNextChunk())
Expand Down Expand Up @@ -7639,6 +7641,11 @@ namespace binpack
// now shuffle the local buffer
auto& prng = rng::get_thread_local_rng();
std::shuffle(m_localBuffer.begin(), m_localBuffer.end(), prng);
if (m_bufferFilter)
{
m_bufferFilter(m_localBuffer);
std::shuffle(m_localBuffer.begin(), m_localBuffer.end(), prng);
}

std::unique_lock lock(m_waitingBufferMutex);
m_waitingBufferEmpty.wait(lock, [this]() { return m_waitingBuffer.empty() || m_stopFlag.load(); });
Expand Down Expand Up @@ -7756,6 +7763,7 @@ namespace binpack
std::condition_variable m_waitingBufferEmpty;
std::condition_variable m_waitingBufferFull;
std::function<bool(const TrainingDataEntry&)> m_skipPredicate;
std::function<void(std::vector<TrainingDataEntry>&)> m_bufferFilter;

std::vector<std::thread> m_workers;

Expand Down
20 changes: 16 additions & 4 deletions lib/nnue_training_data_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ namespace training_data {
static constexpr auto openmode = std::ios::in | std::ios::binary;
static inline const std::string extension = "binpack";

BinpackSfenInputParallelStream(int concurrency, std::string filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filename, openmode, skipPredicate)),
BinpackSfenInputParallelStream(
int concurrency,
std::string filename,
bool cyclic,
std::function<bool(const TrainingDataEntry&)> skipPredicate,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter
) :
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filename, openmode, std::move(skipPredicate), std::move(bufferFilter))),
m_filename(filename),
m_concurrency(concurrency),
m_eof(false),
Expand Down Expand Up @@ -272,13 +278,19 @@ namespace training_data {
return nullptr;
}

// Opens `filename` as a training-data input stream, picking the reader from the file extension.
//
// - concurrency:   forwarded only to BinpackSfenInputParallelStream; the BinSfenInputStream
//                  branch does not receive it.
// - filename:      path whose extension is matched against each stream type's `extension`.
// - cyclic:        forwarded unchanged to whichever stream is constructed.
// - skipPredicate: optional per-entry predicate, moved into the constructed stream.
// - bufferFilter:  optional callback taking the entry vector by reference; it is forwarded
//                  only to the binpack stream and is dropped for BinSfenInputStream.
//
// Returns nullptr when the extension matches neither stream type.
//
// NOTE(review): this span is a rendered diff, so it appears to contain both the pre-change
// lines (single-line signature, 4-argument make_unique) and the post-change lines
// (multi-line signature, 5-argument make_unique) - confirm against the real header.
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::string& filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(
int concurrency,
const std::string& filename,
bool cyclic,
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter = nullptr
)
{
// TODO (low priority): optimize and parallelize .bin reading.
// Extension check for the BinSfenInputStream format; this reader takes no concurrency argument.
if (has_extension(filename, BinSfenInputStream::extension))
return std::make_unique<BinSfenInputStream>(filename, cyclic, std::move(skipPredicate));
// Extension check for the binpack format, which is read by the parallel stream.
else if (has_extension(filename, BinpackSfenInputParallelStream::extension))
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filename, cyclic, std::move(skipPredicate));
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filename, cyclic, std::move(skipPredicate), std::move(bufferFilter));

// No known extension matched.
return nullptr;
}
Expand Down
202 changes: 121 additions & 81 deletions training_data_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,14 @@ struct Stream : AnyStream
{
using StorageType = StorageT;

Stream(int concurrency, const char* filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
m_stream(training_data::open_sfen_input_file_parallel(concurrency, filename, cyclic, skipPredicate))
Stream(
int concurrency,
const char* filename,
bool cyclic,
std::function<bool(const TrainingDataEntry&)> skipPredicate,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter
) :
m_stream(training_data::open_sfen_input_file_parallel(concurrency, filename, cyclic, std::move(skipPredicate), std::move(bufferFilter)))
{
}

Expand All @@ -478,8 +484,14 @@ struct AsyncStream : Stream<StorageT>
{
using BaseType = Stream<StorageT>;

AsyncStream(int concurrency, const char* filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
BaseType(1, filename, cyclic, skipPredicate)
AsyncStream(
int concurrency,
const char* filename,
bool cyclic,
std::function<bool(const TrainingDataEntry&)> skipPredicate,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter
) :
BaseType(1, filename, cyclic, std::move(skipPredicate), std::move(bufferFilter))
{
}

Expand All @@ -505,15 +517,23 @@ struct FeaturedBatchStream : Stream<StorageT>

static constexpr int num_feature_threads_per_reading_thread = 2;

FeaturedBatchStream(int concurrency, const char* filename, int batch_size, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
FeaturedBatchStream(
int concurrency,
const char* filename,
int batch_size,
bool cyclic,
std::function<bool(const TrainingDataEntry&)> skipPredicate,
std::function<void(std::vector<TrainingDataEntry>&)> bufferFilter
) :
BaseType(
std::max(
1,
concurrency / num_feature_threads_per_reading_thread
),
filename,
cyclic,
skipPredicate
std::move(skipPredicate),
std::move(bufferFilter)
),
m_concurrency(concurrency),
m_batch_size(batch_size)
Expand Down Expand Up @@ -698,7 +718,8 @@ struct FenBatchStream : Stream<FenBatch>
),
filename,
cyclic,
skipPredicate
skipPredicate,
nullptr // TODO: if needed
),
m_concurrency(concurrency),
m_batch_size(batch_size)
Expand Down Expand Up @@ -808,6 +829,90 @@ struct FenBatchStream : Stream<FenBatch>
std::vector<std::thread> m_workers;
};

std::function<void(std::vector<TrainingDataEntry>&)> make_buffer_filter()
{
return [
](std::vector<TrainingDataEntry>& buffer) {
static constexpr double desired_piece_count_weights[33] = {
1.000000,
1.121094, 1.234375, 1.339844, 1.437500, 1.527344, 1.609375, 1.683594, 1.750000,
1.808594, 1.859375, 1.902344, 1.937500, 1.964844, 1.984375, 1.996094, 2.000000,
1.996094, 1.984375, 1.964844, 1.937500, 1.902344, 1.859375, 1.808594, 1.750000,
1.683594, 1.609375, 1.527344, 1.437500, 1.339844, 1.234375, 1.121094, 1.000000
};

static constexpr double desired_piece_count_weights_total = [](){
double tot = 0;
for (auto w : desired_piece_count_weights)
tot += w;
return tot;
}();

// Basically ignore n lower piece counts.
static constexpr int nth_base = 3;
static constexpr int max_skipping_factor = 5;

int pc_hist[33] = {0};

for (const auto& entry : buffer)
{
const int pc = entry.pos.piecesBB().count();
pc_hist[pc] += 1;
}

std::vector<std::pair<int, double>> adjusted_hist_by_pc;
for (int i = 2; i <= 32; ++i)
{
adjusted_hist_by_pc.emplace_back(i, pc_hist[i] / desired_piece_count_weights[i]);
}
std::sort(adjusted_hist_by_pc.begin(), adjusted_hist_by_pc.end(), [](const std::pair<int, double>& lhs, const std::pair<int, double>& rhs){
return lhs.second < rhs.second;
});
const int base = adjusted_hist_by_pc[nth_base].second;
int pc_hist_desired[33] = {0};
for (int i = 2; i <= 32; ++i)
{
pc_hist_desired[i] = std::max(
static_cast<int>(base * desired_piece_count_weights[i]),
static_cast<int>(pc_hist[i] / max_skipping_factor)
);
}

auto begin = buffer.begin();
auto end = buffer.end();
while (begin != end)
{
const int pc = begin->pos.piecesBB().count();
if (pc_hist_desired[pc] > 0)
{
pc_hist_desired[pc] -= 1;
++begin;
}
else
{
--end;
std::swap(*begin, *end);
}
}
begin = buffer.begin(); // reassign begin because we were using it as a current pointer

constexpr bool do_debug_print = false;
if (do_debug_print) {
int pc_hist2[33] = {0};
for (auto b = begin; b != end; ++b)
{
const int pc = b->pos.piecesBB().count();
pc_hist2[pc] += 1;
}
std::cout << "Total : " << buffer.size() << '\n';
std::cout << "Passed: " << (end - begin) << '\n';
for (int i = 0; i < 33; ++i)
std::cout << i << ' ' << pc_hist2[i] << '\n';
}
buffer.resize(end - begin);
};
}

std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered, int random_fen_skipping, bool wld_filtered)
{
if (filtered || random_fen_skipping || wld_filtered)
Expand All @@ -819,33 +924,8 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered,
wld_filtered
](const TrainingDataEntry& e){

static constexpr double desired_piece_count_weights[33] = {
1.000000,
1.121094, 1.234375, 1.339844, 1.437500, 1.527344, 1.609375, 1.683594, 1.750000,
1.808594, 1.859375, 1.902344, 1.937500, 1.964844, 1.984375, 1.996094, 2.000000,
1.996094, 1.984375, 1.964844, 1.937500, 1.902344, 1.859375, 1.808594, 1.750000,
1.683594, 1.609375, 1.527344, 1.437500, 1.339844, 1.234375, 1.121094, 1.000000
};

static constexpr double desired_piece_count_weights_total = [](){
double tot = 0;
for (auto w : desired_piece_count_weights)
tot += w;
return tot;
}();

static thread_local std::mt19937 gen(std::random_device{}());

// keep stats on passing pieces
static thread_local double alpha = 1;
static thread_local double piece_count_history_all[33] = {0};
static thread_local double piece_count_history_passed[33] = {0};
static thread_local double piece_count_history_all_total = 0;
static thread_local double piece_count_history_passed_total = 0;

// max skipping rate
static constexpr double max_skipping_rate = 10.0;

auto do_wld_skip = [&]() {
std::bernoulli_distribution distrib(1.0 - e.score_result_prob());
auto& prng = rng::get_thread_local_rng();
Expand All @@ -871,47 +951,6 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(bool filtered,
if (wld_filtered && do_wld_skip())
return true;

constexpr bool do_debug_print = false;
if (do_debug_print) {
if (uint64_t(piece_count_history_all_total) % 10000 == 0) {
std::cout << "Total : " << piece_count_history_all_total << '\n';
std::cout << "Passed: " << piece_count_history_passed_total << '\n';
for (int i = 0; i < 33; ++i)
std::cout << i << ' ' << piece_count_history_passed[i] << '\n';
}
}

const int pc = e.pos.piecesBB().count();
piece_count_history_all[pc] += 1;
piece_count_history_all_total += 1;

// update alpha, which scales the filtering probability, to a maximum rate.
if (uint64_t(piece_count_history_all_total) % 10000 == 0) {
double pass = piece_count_history_all_total * desired_piece_count_weights_total;
for (int i = 0; i < 33; ++i)
{
if (desired_piece_count_weights[pc] > 0)
{
double tmp = piece_count_history_all_total * desired_piece_count_weights[pc] /
(desired_piece_count_weights_total * piece_count_history_all[pc]);
if (tmp < pass)
pass = tmp;
}
}
alpha = 1.0 / (pass * max_skipping_rate);
}

double tmp = alpha * piece_count_history_all_total * desired_piece_count_weights[pc] /
(desired_piece_count_weights_total * piece_count_history_all[pc]);
tmp = std::min(1.0, tmp);
std::bernoulli_distribution distrib(1.0 - tmp);
auto& prng = rng::get_thread_local_rng();
if (distrib(prng))
return true;

piece_count_history_passed[pc] += 1;
piece_count_history_passed_total += 1;

return false;
};
}
Expand Down Expand Up @@ -997,39 +1036,40 @@ extern "C" {
bool filtered, int random_fen_skipping, bool wld_filtered)
{
auto skipPredicate = make_skip_predicate(filtered, random_fen_skipping, wld_filtered);
auto bufferFilter = make_buffer_filter();

std::string_view feature_set(feature_set_c);
if (feature_set == "HalfKP")
{
return new FeaturedBatchStream<FeatureSet<HalfKP>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKP>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKP^")
{
return new FeaturedBatchStream<FeatureSet<HalfKPFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKPFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKA")
{
return new FeaturedBatchStream<FeatureSet<HalfKA>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKA>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKA^")
{
return new FeaturedBatchStream<FeatureSet<HalfKAFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKAFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKAv2")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKAv2>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKAv2^")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKAv2_hm")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
else if (feature_set == "HalfKAv2_hm^")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate, bufferFilter);
}
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
return nullptr;
Expand Down