-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
77 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
Three csv files. | ||
pgn: has header uid,rating,FEN,solution | ||
Proofgame_pgn: uid,rating,FEN,solution,proofgame | ||
Original_pgn: has header uid,rating,pgn,solution | ||
Create a new one with header uid,rating,pgn,proofgame,solution. | ||
Do not include rows where proofgame is None. | ||
""" | ||
|
||
import pandas as pd | ||
import argparse | ||
import os | ||
|
||
def merge_files(data_dir, pgn_file, proofgame_file, original_file, output_file):
    """Join the original PGN puzzles with their proof games and write a CSV.

    The output has the columns uid, rating, pgn, proofgame, solution (in that
    order) and contains only puzzles for which a proof game exists.

    Args:
        data_dir: Directory the three input file names are joined onto.
        pgn_file: Name of the FEN puzzle CSV. Accepted for backward
            compatibility only; none of its columns reach the output, so the
            file is not read.
        proofgame_file: Name of a CSV with at least the columns ``uid`` and
            ``proofgame``.
        original_file: Name of a CSV with the columns uid, rating, pgn,
            solution. The header row may be missing.
        output_file: Path of the CSV to write. It is used as given and is
            NOT joined onto ``data_dir``.
    """
    original_columns = ['uid', 'rating', 'pgn', 'solution']
    output_order = ['uid', 'rating', 'pgn', 'proofgame', 'solution']

    proofgame_pgn = pd.read_csv(os.path.join(data_dir, proofgame_file))

    original_path = os.path.join(data_dir, original_file)
    # Inspect only the first line to decide whether it is a header. Renaming
    # the columns after a normal read would silently discard the first puzzle,
    # because read_csv would already have consumed it as the header row.
    if 'uid' in pd.read_csv(original_path, nrows=0).columns:
        original_pgn = pd.read_csv(original_path)
    else:
        original_pgn = pd.read_csv(original_path, header=None, names=original_columns)

    # Inner join: keep only puzzles present in both files.
    merged_df = pd.merge(original_pgn, proofgame_pgn[['uid', 'proofgame']], on='uid', how='inner')

    # Drop rows where proofgame is missing (empty cell in the proofgame file).
    merged_df = merged_df[merged_df['proofgame'].notna()]

    # Put the documented columns first, in the documented order; any extra
    # columns the input happened to carry are kept after them.
    leading = [col for col in output_order if col in merged_df.columns]
    trailing = [col for col in merged_df.columns if col not in leading]
    merged_df = merged_df[leading + trailing]

    merged_df.to_csv(output_file, index=False)
|
||
if __name__ == "__main__":
    # Command-line entry point: every input name is resolved against
    # --data_dir, while --output is used exactly as given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", "-d", default="/data/chess-data/lichess_puzzles/", help="Directory containing the data files. Can be / if you want to use full paths")
    parser.add_argument("--pgn_file", "-p", default="fen_puzzles.csv", help="Name of the pgn file")
    parser.add_argument("--proofgame_file", "-pg", default="proofgame_pgns.csv", help="Name of the proofgame file")
    parser.add_argument("--original_file", "-o", default="pgn_puzzles.csv", help="Name of the original pgn file")
    parser.add_argument("--output", "-out", default="/data/chess-data/lichess_puzzles/pairs.csv", help="Name of the output file")
    args = parser.parse_args()

    merge_files(args.data_dir, args.pgn_file, args.proofgame_file, args.original_file, output_file=args.output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import chess | ||
import numpy as np | ||
import io | ||
import json | ||
import csv | ||
from pathlib import Path | ||
from tqdm import tqdm | ||
from puzzle_solver import convert_pgn_to_game, solve_puzzle | ||
import chessllm | ||
|
||
# Directory holding the lichess puzzle data.
DATA_DIR = Path("/data/chess-data/lichess_puzzles")
# CSV of puzzle pairs with the columns uid,rating,pgn,proofgame,solution.
FILE_NAME = DATA_DIR / "pairs.csv"
|
||
""" | ||
Solve puzzle pairs given in FILE_NAME, and report whether the model can solve them. | ||
Separate by rating buckets; take 40 samples from each bucket. | ||
The file has the following columns: uid,rating,pgn,proofgame,solution. | ||
Helper functions: | ||
def solve_puzzle(board, solution) -> bool: whether model can solve the puzzle | ||
convert_pgn_to_game(pgn_moves) -> game | ||
""" | ||
|
||
|
||
def main():
    """Placeholder entry point for the puzzle-pair evaluation.

    Raises:
        NotImplementedError: Always; the evaluation has not been written yet.
    """
    raise NotImplementedError("This script is not finished")
|
||
|
||
if __name__ == "__main__":
    # Read the credentials and model configuration from the working
    # directory. The context managers close both files promptly instead of
    # leaving the handles to the garbage collector.
    with open("OPENAI_API_KEY") as key_file:
        api_key = key_file.read().strip()
    with open("config.json") as config_file:
        config = json.loads(config_file.read())
    # Bound at module level, exactly as before, so it stays visible as a
    # global named `engine` while main() runs.
    engine = chessllm.ChessLLM(api_key, config, num_lookahead_tokens=30)
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters