From 5628f76544d941eda2034366b986157dc6f328ca Mon Sep 17 00:00:00 2001 From: Fabian Fichter Date: Sun, 7 Jul 2024 17:29:36 +0200 Subject: [PATCH] Support various filtering --- src/tools/training_data_generator.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/tools/training_data_generator.cpp b/src/tools/training_data_generator.cpp index e5534b69..cf4c2ac4 100644 --- a/src/tools/training_data_generator.cpp +++ b/src/tools/training_data_generator.cpp @@ -95,6 +95,9 @@ namespace Stockfish::Tools bool write_out_draw_game_in_training_data_generation = true; bool detect_draw_by_consecutive_low_score = true; bool detect_draw_by_insufficient_mating_material = true; + bool filter_captures = false; + bool filter_checks = false; + bool filter_promotions = false; uint64_t num_threads; @@ -356,7 +359,11 @@ namespace Stockfish::Tools // Filter for static positions using abs(qsearch_value - eval_value) // sync_cout << pos.fen() << " | " << search_value << " | " << qsearch_value << " | " << eval_value << sync_endl; if (ply >= params.write_minply && !was_seen_before(pos) - && !pos.checkers() && pos.nnue_applicable() && std::abs(qsearch_value - eval_value) <= params.eval_diff_limit) + && !pos.checkers() && pos.nnue_applicable() + && std::abs(qsearch_value - eval_value) <= params.eval_diff_limit + && !(params.filter_captures && pos.capture(search_pv[0])) + && !(params.filter_checks && pos.gives_check(search_pv[0])) + && !(params.filter_promotions && (type_of(search_pv[0]) == PROMOTION || type_of(search_pv[0]) == PIECE_PROMOTION))) { auto& psv = packed_sfens.emplace_back(); @@ -797,6 +804,12 @@ namespace Stockfish::Tools is >> params.detect_draw_by_consecutive_low_score; else if (token == "adjudicate_draws_by_insufficient_material") is >> params.detect_draw_by_insufficient_mating_material; + else if (token == "filter_captures") + is >> params.filter_captures; + else if (token == "filter_checks") + is >> params.filter_checks; + else if (token == "filter_promotions") + is >> params.filter_promotions; else if (token == "data_format") is >> sfen_format; else if (token == "seed") @@ -875,7 +888,10 @@ namespace Stockfish::Tools << " - random_file_name = " << random_file_name << endl << " - write_drawn_games = " << params.write_out_draw_game_in_training_data_generation << endl << " - draw by low score = " << params.detect_draw_by_consecutive_low_score << endl - << " - draw by insuff. mat. = " << params.detect_draw_by_insufficient_mating_material << endl; + << " - draw by insuff. mat. = " << params.detect_draw_by_insufficient_mating_material << endl + << " - filter_captures = " << params.filter_captures << endl + << " - filter_checks = " << params.filter_checks << endl + << " - filter_promotions = " << params.filter_promotions << endl; // Show if the training data generator uses NNUE. Eval::NNUE::verify();