
Commit

wip testing pairs
dpaleka committed Oct 24, 2023
1 parent 21df1fd commit 1991a8a
Showing 3 changed files with 77 additions and 1 deletion.
42 changes: 42 additions & 0 deletions make_pairs_puzzles_dataset.py
@@ -0,0 +1,42 @@
"""
Three csv files.
pgn: has header uid,rating,FEN,solution
Proofgame_pgn: uid,rating,FEN,solution,proofgame
Original_pgn: has header uid,rating,pgn,solution
Create a new one with header uid,rating,pgn,proofgame,solution.
Do not include columns where proofgame is None.
"""

import pandas as pd
import argparse
import os

def merge_files(data_dir, pgn_file, proofgame_file, original_file, output_file):
# Load the csv files
pgn = pd.read_csv(os.path.join(data_dir, pgn_file))
proofgame_pgn = pd.read_csv(os.path.join(data_dir, proofgame_file))
original_pgn = pd.read_csv(os.path.join(data_dir, original_file))
# original pgn may not have a header
if 'uid' not in original_pgn.columns:
original_pgn.columns = ['uid', 'rating', 'pgn', 'solution']

# Merge the dataframes
merged_df = pd.merge(original_pgn, proofgame_pgn[['uid', 'proofgame']], on='uid', how='inner')

# Drop rows where proofgame is None
merged_df = merged_df[merged_df['proofgame'].notna()]

# Write the merged dataframe to a new csv file
merged_df.to_csv(output_file, index=False)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", "-d", default="/data/chess-data/lichess_puzzles/", help="Directory containing the data files. Can be / if you want to use full paths")
parser.add_argument("--pgn_file", "-p", default="fen_puzzles.csv", help="Name of the pgn file")
parser.add_argument("--proofgame_file", "-pg", default="proofgame_pgns.csv", help="Name of the proofgame file")
parser.add_argument("--original_file", "-o", default="pgn_puzzles.csv", help="Name of the original pgn file")
parser.add_argument("--output", "-out", default="/data/chess-data/lichess_puzzles/pairs.csv", help="Name of the output file")
args = parser.parse_args()

merge_files(args.data_dir, args.pgn_file, args.proofgame_file, args.original_file, output_file=args.output)
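
A toy illustration, not part of the commit, of what the merge-and-filter step in merge_files produces: an inner join of the original and proofgame tables on uid, followed by dropping rows whose proofgame is missing. The uids, ratings, and move strings below are made up for illustration.

import pandas as pd

original_pgn = pd.DataFrame({
    "uid": ["a1", "b2", "c3"],
    "rating": [1200, 1500, 1800],
    "pgn": ["1. e4 e5", "1. d4 d5", "1. c4 c5"],
    "solution": ["e7e5", "d7d5", "c7c5"],
})
proofgame_pgn = pd.DataFrame({
    "uid": ["a1", "b2", "c3"],
    "proofgame": ["1. e4 e5", None, "1. c4 c5"],
})

# Inner join on uid, then drop rows with a missing proofgame
merged = pd.merge(original_pgn, proofgame_pgn[["uid", "proofgame"]], on="uid", how="inner")
merged = merged[merged["proofgame"].notna()]
print(merged)  # keeps uids a1 and c3, with columns uid,rating,pgn,solution,proofgame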
33 changes: 33 additions & 0 deletions puzzle_pair_solve.py
@@ -0,0 +1,33 @@
import chess
import numpy as np
import io
import json
import csv
from pathlib import Path
from tqdm import tqdm
from puzzle_solver import convert_pgn_to_game, solve_puzzle
import chessllm

DATA_DIR = Path("/data/chess-data/lichess_puzzles")
FILE_NAME = DATA_DIR / "pairs.csv"

"""
Solve puzzle pairs given in FILE_NAME, and report whether the model can solve them.
Separate by rating buckets; take 40 samples from each bucket.
It has the following columns: uid,rating,pgn,proofgame,solution.
Helper functions:
def solve_puzzle(board, solution) -> bool: whether model can solve the puzzle
convert_pgn_to_game(pgn_moves) -> game
"""


def main():
raise NotImplementedError("This script is not finished")


if __name__ == "__main__":
api_key = open("OPENAI_API_KEY").read().strip()
config = json.loads(open("config.json").read())
engine = chessllm.ChessLLM(api_key, config, num_lookahead_tokens=30)
main()
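
The commit leaves main() as a stub. Below is a rough sketch, not the author's implementation, of how the bucketing and solving described in the docstring could look. It assumes convert_pgn_to_game returns a chess.pgn.Game (so .end().board() yields the final position), that solve_puzzle(board, solution) returns a bool as documented, and invents a 200-point bucket width; sketch_main is a hypothetical name.

import csv
from collections import defaultdict
from pathlib import Path
from puzzle_solver import convert_pgn_to_game, solve_puzzle

FILE_NAME = Path("/data/chess-data/lichess_puzzles") / "pairs.csv"

def sketch_main():
    # Sort pairs into rating buckets (assumed width of 200 rating points).
    buckets = defaultdict(list)
    with open(FILE_NAME) as f:
        for row in csv.DictReader(f):
            buckets[int(row["rating"]) // 200 * 200].append(row)

    for bucket, rows in sorted(buckets.items()):
        sample = rows[:40]  # docstring: take 40 samples from each bucket
        solved_orig = solved_proof = 0
        for row in sample:
            # Assumes the returned game object exposes .end().board(), as chess.pgn.Game does.
            orig_board = convert_pgn_to_game(row["pgn"]).end().board()
            proof_board = convert_pgn_to_game(row["proofgame"]).end().board()
            solved_orig += solve_puzzle(orig_board, row["solution"])
            solved_proof += solve_puzzle(proof_board, row["solution"])
        print(f"rating {bucket}-{bucket + 199}: "
              f"original {solved_orig}/{len(sample)}, proofgame {solved_proof}/{len(sample)}")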
3 changes: 2 additions & 1 deletion puzzle_solver.py
@@ -56,6 +56,7 @@ def solve_puzzle(board, solution):


DATA_DIR = Path("/data/chess-data/lichess_puzzles")
FILE_NAME = DATA_DIR / "pgn_puzzles.csv"

def main():

@@ -65,7 +66,7 @@ def main():
enough_samples = 40

# Read the data and sort into buckets
with open(DATA_DIR / "pgn_puzzles.csv") as f:
with open(FILE_NAME) as f:
reader = csv.reader(f)
print(reader.__next__())
for puzzleid, rating, pgn, solution in tqdm(list(reader)):
