Skip to content

Commit

Permalink
continuation Correction History
Browse files Browse the repository at this point in the history
  • Loading branch information
OmerFarukTutkun committed Oct 4, 2024
1 parent 81c1d31 commit 01d6d00
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/movepick.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ using CapturePieceToHistory = Stats<int16_t, 10692, PIECE_NB, SQUARE_NB, PIECE_T
// PieceToHistory is like ButterflyHistory but is addressed by a move's [piece][to]
using PieceToHistory = Stats<int16_t, 29952, PIECE_NB, SQUARE_NB>;

// PieceToCorrectionHistory is addressed by a move's [piece][to]
using PieceToCorrectionHistory = Stats<int16_t, CORRECTION_HISTORY_LIMIT, PIECE_NB, SQUARE_NB>;

// ContinuationHistory is the combined history of a given pair of moves, usually
// the current one given a previous one. The nested history table is based on
// PieceToHistory instead of ButterflyBoards.
Expand Down Expand Up @@ -179,6 +182,9 @@ using MinorPieceCorrectionHistory =
using NonPawnCorrectionHistory =
Stats<int16_t, CORRECTION_HISTORY_LIMIT, COLOR_NB, CORRECTION_HISTORY_SIZE>;

// ContinuationCorrectionHistory is the combined correction history of a given pair of moves
using ContinuationCorrectionHistory = Stats<PieceToCorrectionHistory, NOT_USED, PIECE_NB, SQUARE_NB>;

// The MovePicker class is used to pick one pseudo-legal move at a time from the
// current position. The most important method is next_move(), which emits one
// new pseudo-legal move on every call, until there are no moves left, when
Expand Down
33 changes: 26 additions & 7 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,19 @@ constexpr int futility_move_count(bool improving, Depth depth) {

// Add correctionHistory value to raw staticEval and guarantee evaluation
// does not hit the tablebase range.
Value to_corrected_static_eval(Value v, const Worker& w, const Position& pos) {
Value to_corrected_static_eval(Value v, const Worker& w, const Position& pos,Stack * ss) {
const Color us = pos.side_to_move();
const auto m = (ss-1)->currentMove;
const auto pcv = w.pawnCorrectionHistory[us][pawn_structure_index<Correction>(pos)];
const auto mcv = w.materialCorrectionHistory[us][material_index(pos)];
const auto macv = w.majorPieceCorrectionHistory[us][major_piece_index(pos)];
const auto micv = w.minorPieceCorrectionHistory[us][minor_piece_index(pos)];
const auto wnpcv = w.nonPawnCorrectionHistory[WHITE][us][non_pawn_index<WHITE>(pos)];
const auto bnpcv = w.nonPawnCorrectionHistory[BLACK][us][non_pawn_index<BLACK>(pos)];
const auto cntcv = (*(ss - 2)->continuationCorrectionHistory)[pos.piece_on(m.to_sq())][m.to_sq()];

const auto cv =
(6245 * pcv + 3442 * mcv + 3471 * macv + 5958 * micv + 6566 * (wnpcv + bnpcv)) / 131072;
(5932 * pcv + 2994 * mcv + 3269 * macv + 5660 * micv + 6237 * (wnpcv + bnpcv) + cntcv*5555) / 131072;
v += cv;
return std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1);
}
Expand Down Expand Up @@ -240,6 +243,8 @@ void Search::Worker::iterative_deepening() {
{
(ss - i)->continuationHistory =
&this->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel
(ss - i)->continuationCorrectionHistory =
&this->continuationCorrectionHistory[NO_PIECE][0];
(ss - i)->staticEval = VALUE_NONE;
}

Expand Down Expand Up @@ -504,6 +509,10 @@ void Search::Worker::clear() {
nonPawnCorrectionHistory[WHITE].fill(0);
nonPawnCorrectionHistory[BLACK].fill(0);

for (auto& to : continuationCorrectionHistory)
for (auto& h : to)
h->fill(0);

for (bool inCheck : {false, true})
for (StatsType c : {NoCaptures, Captures})
for (auto& to : continuationHistory[inCheck][c])
Expand Down Expand Up @@ -727,7 +736,7 @@ Value Search::Worker::search(
else if (PvNode)
Eval::NNUE::hint_common_parent_position(pos, networks[numaAccessToken], refreshTable);

ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos);
ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);

// ttValue can be used as a better position evaluation (~7 Elo)
if (ttData.value != VALUE_NONE
Expand All @@ -738,7 +747,7 @@ Value Search::Worker::search(
{
unadjustedStaticEval =
evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us]);
ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos);
ss->staticEval = eval = to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);

// Static evaluation is saved as it was before adjustment by correction history
ttWriter.write(posKey, VALUE_NONE, ss->ttPv, BOUND_NONE, DEPTH_UNSEARCHED, Move::none(),
Expand Down Expand Up @@ -795,6 +804,7 @@ Value Search::Worker::search(

ss->currentMove = Move::null();
ss->continuationHistory = &thisThread->continuationHistory[0][0][NO_PIECE][0];
ss->continuationCorrectionHistory = &thisThread->continuationCorrectionHistory[NO_PIECE][0];

pos.do_null_move(st, tt);

Expand Down Expand Up @@ -876,6 +886,8 @@ Value Search::Worker::search(
ss->currentMove = move;
ss->continuationHistory =
&this->continuationHistory[ss->inCheck][true][pos.moved_piece(move)][move.to_sq()];
ss->continuationCorrectionHistory =
&this->continuationCorrectionHistory[pos.moved_piece(move)][move.to_sq()];

thisThread->nodes.fetch_add(1, std::memory_order_relaxed);
pos.do_move(move, st);
Expand Down Expand Up @@ -1124,7 +1136,8 @@ Value Search::Worker::search(
ss->currentMove = move;
ss->continuationHistory =
&thisThread->continuationHistory[ss->inCheck][capture][movedPiece][move.to_sq()];

ss->continuationCorrectionHistory =
&thisThread->continuationCorrectionHistory[movedPiece][move.to_sq()];
uint64_t nodeCount = rootNode ? uint64_t(nodes) : 0;

// Step 16. Make the move
Expand Down Expand Up @@ -1401,6 +1414,8 @@ Value Search::Worker::search(
&& !(bestValue >= beta && bestValue <= ss->staticEval)
&& !(!bestMove && bestValue >= ss->staticEval))
{
const auto m = (ss-1)->currentMove;

auto bonus = std::clamp(int(bestValue - ss->staticEval) * depth / 8,
-CORRECTION_HISTORY_LIMIT / 4, CORRECTION_HISTORY_LIMIT / 4);
thisThread->pawnCorrectionHistory[us][pawn_structure_index<Correction>(pos)]
Expand All @@ -1412,6 +1427,7 @@ Value Search::Worker::search(
<< bonus * 123 / 128;
thisThread->nonPawnCorrectionHistory[BLACK][us][non_pawn_index<BLACK>(pos)]
<< bonus * 165 / 128;
(*(ss - 2)->continuationCorrectionHistory)[pos.piece_on(m.to_sq())][m.to_sq()] << bonus;
}

assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE);
Expand Down Expand Up @@ -1507,7 +1523,7 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
unadjustedStaticEval =
evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us]);
ss->staticEval = bestValue =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos);
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);

// ttValue can be used as a better position evaluation (~13 Elo)
if (std::abs(ttData.value) < VALUE_TB_WIN_IN_MAX_PLY
Expand All @@ -1522,7 +1538,7 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
? evaluate(networks[numaAccessToken], pos, refreshTable, thisThread->optimism[us])
: -(ss - 1)->staticEval;
ss->staticEval = bestValue =
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos);
to_corrected_static_eval(unadjustedStaticEval, *thisThread, pos, ss);
}

// Stand pat. Return immediately if static value is at least beta
Expand Down Expand Up @@ -1619,6 +1635,9 @@ Value Search::Worker::qsearch(Position& pos, Stack* ss, Value alpha, Value beta)
ss->continuationHistory =
&thisThread
->continuationHistory[ss->inCheck][capture][pos.moved_piece(move)][move.to_sq()];
ss->continuationCorrectionHistory =
&thisThread
->continuationCorrectionHistory[pos.moved_piece(move)][move.to_sq()];

// Step 7. Make and search the move
thisThread->nodes.fetch_add(1, std::memory_order_relaxed);
Expand Down
2 changes: 2 additions & 0 deletions src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ namespace Search {
struct Stack {
Move* pv;
PieceToHistory* continuationHistory;
PieceToCorrectionHistory* continuationCorrectionHistory;
int ply;
Move currentMove;
Move excludedMove;
Expand Down Expand Up @@ -289,6 +290,7 @@ class Worker {
MajorPieceCorrectionHistory majorPieceCorrectionHistory;
MinorPieceCorrectionHistory minorPieceCorrectionHistory;
NonPawnCorrectionHistory nonPawnCorrectionHistory[COLOR_NB];
ContinuationCorrectionHistory continuationCorrectionHistory;

private:
void iterative_deepening();
Expand Down

0 comments on commit 01d6d00

Please sign in to comment.