From 43a3138bda78ddbb4878b2d98668681960cec649 Mon Sep 17 00:00:00 2001 From: Daniel Paleka Date: Wed, 25 Oct 2023 16:30:42 +0200 Subject: [PATCH] experiment with pairs of pgns, with gpt-3.5-turbo --- make_pairs_puzzles_dataset.py | 2 ++ puzzle_pair_solve.py | 65 ++++++++++++++++++++++++++++++++--- puzzle_solver.py | 6 ++-- 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/make_pairs_puzzles_dataset.py b/make_pairs_puzzles_dataset.py index ecf366e..09c82b6 100644 --- a/make_pairs_puzzles_dataset.py +++ b/make_pairs_puzzles_dataset.py @@ -28,6 +28,8 @@ def merge_files(data_dir, pgn_file, proofgame_file, original_file, output_file): merged_df = merged_df[merged_df['proofgame'].notna()] # Write the merged dataframe to a new csv file + # reorder so it's uid, rating, pgn, proofgame, solution + merged_df = merged_df[['uid', 'rating', 'pgn', 'proofgame', 'solution']] merged_df.to_csv(output_file, index=False) if __name__ == "__main__": diff --git a/puzzle_pair_solve.py b/puzzle_pair_solve.py index f9aa514..295bdbb 100644 --- a/puzzle_pair_solve.py +++ b/puzzle_pair_solve.py @@ -14,20 +14,77 @@ """ Solve puzzle pairs given in FILE_NAME, and report whether the model can solve them. Separate by rating buckets; take 40 samples from each bucket. -It has the following columns: uid,rating,pgn,proofgame,solution. +It has the following columns: uid,rating,pgn,proofgame,solution Helper functions: def solve_puzzle(board, solution) -> bool: whether model can solve the puzzle convert_pgn_to_game(pgn_moves) -> game """ +import chess +import numpy as np +import io +import json +import csv +from pathlib import Path +from tqdm import tqdm +from puzzle_solver import convert_pgn_to_game, solve_puzzle +import chessllm + +DATA_DIR = Path("/data/chess-data/lichess_puzzles") +FILE_NAME = DATA_DIR / "pairs.csv" + +def main(engine): + # Create buckets + bucket_size = 200 + buckets = {i*bucket_size: [] for i in range(30)} + enough_samples = 10 + + # Read the data and sort into buckets + with open(FILE_NAME) as f: + reader = csv.reader(f) + print(reader.__next__()) + for uid, rating, pgn, proofgame, solution in tqdm(list(reader)): + rating_bucket = int(rating) // bucket_size * bucket_size + if len(buckets[rating_bucket]) < enough_samples: + buckets[rating_bucket].append((pgn, proofgame, solution)) + + # print how many elems in buckets + for k, v in buckets.items(): + print(f'rating [{k}, {k + bucket_size})', 'n', len(v)) + + # Test the puzzles + ok_pgn = {i*bucket_size: [] for i in range(30)} + ok_proofgame = {i*bucket_size: [] for i in range(30)} + for rating_bucket, puzzles in tqdm(buckets.items()): + for pgn, proofgame, solution in puzzles: + board_pgn = chess.Board() + board_proofgame = chess.Board() + + print("pgn origi", pgn) + print("proofgame", proofgame) + # Iterate over the moves and apply them to the board + for move in convert_pgn_to_game(pgn).mainline_moves(): + board_pgn.push(move) + for move in convert_pgn_to_game(proofgame).mainline_moves(): + board_proofgame.push(move) + + is_right_pgn = solve_puzzle(board_pgn, solution, engine) + is_right_proofgame = solve_puzzle(board_proofgame, solution, engine) -def main(): - raise NotImplementedError("This script is not finished") + ok_pgn[rating_bucket].append(is_right_pgn) + ok_proofgame[rating_bucket].append(is_right_proofgame) + # Compare the results + for i in range(30): + bucket_start = i * bucket_size + if len(ok_pgn[bucket_start]) > 0 and len(ok_proofgame[bucket_start]) > 0: + pgn_acc = np.mean(ok_pgn[bucket_start]) + proofgame_acc = np.mean(ok_proofgame[bucket_start]) + print(f'rating [{bucket_start}, {bucket_start + bucket_size})', f'pgn acc {pgn_acc:.3f}', f'proofgame acc {proofgame_acc:.3f}', 'n', len(ok_pgn[bucket_start])) if __name__ == "__main__": api_key = open("OPENAI_API_KEY").read().strip() config = json.loads(open("config.json").read()) engine = chessllm.ChessLLM(api_key, config, num_lookahead_tokens=30) - main() \ No newline at end of file + main(engine) diff --git a/puzzle_solver.py b/puzzle_solver.py index 37113e4..4d6c32f 100644 --- a/puzzle_solver.py +++ b/puzzle_solver.py @@ -24,13 +24,15 @@ from tqdm import tqdm def convert_pgn_to_game(pgn_moves): + print("pgn_moves", pgn_moves) pgn = io.StringIO(pgn_moves) + print("pgn", pgn) game = chess.pgn.read_game(pgn) if len(game.errors) > 0: return None return game -def solve_puzzle(board, solution): +def solve_puzzle(board, solution, engine): print("Solving puzzle", board.fen(), solution) solution = solution.split() while True: @@ -84,7 +86,7 @@ def main(): for move in convert_pgn_to_game(pgn).mainline_moves(): board.push(move) - is_right = solve_puzzle(board, solution) + is_right = solve_puzzle(board, solution, engine) ok[rating_bucket].append(is_right)