Skip to content

Commit

Permalink
Added pytorch position format
Browse files Browse the repository at this point in the history
  • Loading branch information
datawater committed Aug 17, 2024
1 parent 4ce1dd1 commit a3b525a
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 155 deletions.
2 changes: 1 addition & 1 deletion libcmbr/src/cmbr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ impl CmbrFile {
pub fn serialize(&self) -> Vec<u8> {
return bitcode::serialize(&self).unwrap();
}
}
}
13 changes: 6 additions & 7 deletions libcmbr/src/cmbr/pgntocmbr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl CmbrFile {
convertor: &mut SanToCmbrMvConvertor,
is_compressed: bool,
) -> Result<Self, Box<dyn Error>> {
debug_assert!(is_compressed == false);
debug_assert!(!is_compressed);

let mut file = CmbrFile::new(is_compressed);
let mut board = Chess::new();
Expand All @@ -97,7 +97,7 @@ impl CmbrFile {

let len = ast.len();

(0..len).into_iter().for_each(|game_i| {
(0..len).for_each(|game_i| {
if game_i % 1000 == 0 || game_i == len {
print!("{}\r", game_i as f64 / len as f64 * 100.0);
let _ = std::io::stdout().flush();
Expand Down Expand Up @@ -126,7 +126,7 @@ impl CmbrFile {

// SAFE: Safe
Token::TagString(v) => unsafe {
let _ = cmbr_game.headers.push((
cmbr_game.headers.push((
from_utf8_unchecked(current_key).to_owned(),
from_utf8_unchecked(v).to_owned(),
));
Expand All @@ -144,7 +144,7 @@ impl CmbrFile {
variation_pointers.insert(0, 0);

for (id, variation) in variations_iter {
if variation.0.len() == 0 {
if variation.0.is_empty() {
eprintln!("[WARN] Empty variation on game N{game_i}. Skipping game");
break;
}
Expand Down Expand Up @@ -198,8 +198,7 @@ impl CmbrFile {
Token::NAG(n) => {
let mut nag_numeral =
// SAFE: Safe
u32::from_str_radix(unsafe { from_utf8_unchecked(*n) }, 10)
.unwrap();
(unsafe { from_utf8_unchecked(n) }).parse::<u32>().unwrap();

nag_numeral <<= 8;
nag_numeral |= 0b00001000;
Expand Down Expand Up @@ -237,7 +236,7 @@ impl CmbrFile {
}

Token::MoveAnnotation(an) => cmbr_variation.moves.push(
(((MOVE_ANNOTATION_TO_NAG[an] as u32) << 8) as u32 | 0b10000000)
(((MOVE_ANNOTATION_TO_NAG[an] as u32) << 8) | 0b10000000)
.into(),
),

Expand Down
8 changes: 4 additions & 4 deletions libcmbr/src/cmbr/santocmbrmv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ impl SanToCmbrMvConvertor {
promotion,
} => Self::shakmaty_move_to_cmbr(
role,
&from,
&to,
from,
to,
&capture.is_some(),
promotion,
&san.suffix,
Expand Down Expand Up @@ -134,8 +134,8 @@ impl SanToCmbrMvConvertor {

shakmaty::Move::EnPassant { from, to } => Self::shakmaty_move_to_cmbr(
&Role::Pawn,
&from,
&to,
from,
to,
&true,
&None,
&san.suffix,
Expand Down
6 changes: 3 additions & 3 deletions libcmbr/src/cmbr/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def_enum! (
FlagCapture => 1 << 2,
FlagNag => 1 << 3, // If this flag is set, the first 8 bits of the CMBR are replaced with a NAG index (https://w.wiki/AWUT)

FlagPromotesBishop => (1 << 6) | 0b000000,
FlagPromotesBishop => (1 << 6),
FlagPromotesKnight => (1 << 6) | 0b010000,
FlagPromotesRook => (1 << 6) | 0b100000,
FlagPromotesQueen => (1 << 6) | 0b110000,
Expand Down Expand Up @@ -63,7 +63,7 @@ pub struct CmbrFile {

/// A Struct denoting the structure of a game represented in CMBR
#[cfg_attr(feature = "bitcode", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct CmbrGame {
pub headers: Vec<(String, String)>,
/// Possible values: 'w', 'b', 'd', 'u'.
Expand Down Expand Up @@ -122,4 +122,4 @@ impl CmbrVariation {
comments: Vec::new(),
};
}
}
}
120 changes: 58 additions & 62 deletions libcmbr/src/cmbr/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,11 @@ mod cmbr_tests {
let mut board = Chess::new();

for token in variation.0 {
match token {
PgnToken::Token(t) => match t {
Token::Move(san) => {
let cmbr = convertor.san_to_cmbr(&mut board, san).unwrap();
cmbrs.push(cmbr);
}
_ => {}
},

_ => {}
if let PgnToken::Token(t) = token {
if let Token::Move(san) = t {
let cmbr = convertor.san_to_cmbr(&mut board, san).unwrap();
cmbrs.push(cmbr);
}
}
}
}
Expand All @@ -77,57 +72,58 @@ mod cmbr_tests {

assert_eq!(expected_vec, cmbrs);
}

#[cfg(feature = "benchmark")]
#[bench]
fn bench_san_cmbr(b: &mut Bencher) {
let file_path = get_project_root().unwrap().join("data/twic1544.pgn");
let file = File::open(file_path.clone());

if file.is_err() {
panic!(
"[ERROR] {}. File path: {:?}",
file.err().unwrap(),
file_path
);
}

// SAFE: Safe
let file = unsafe { file.unwrap_unchecked() };
let mmap = unsafe { Mmap::map(&file) };

if mmap.is_err() {
panic!("[ERROR] {}", mmap.err().unwrap());
}

let mut mmap = mmap.unwrap();
let ast = pgn::parse_pgn(&mut mmap);
let mut convertor = SanToCmbrMvConvertor::new(/* 128MB */ 128 * 1024 * 1024);

b.iter(|| {
'game: for game in &ast {
// This clone is fucking it up
for (_, variation) in (&game).0 .1.clone() {
let mut board = Chess::new();

for token in variation.0 {
match token {
PgnToken::Token(t) => match t {
Token::Move(san) => {
let cmbr = convertor.san_to_cmbr(&mut board, san);

if cmbr.is_err() {
continue 'game;
}
}
_ => {}
},

_ => {}
}
}
}
}
});
}
// FIXME: bench_san_cmbr is broken
// #[cfg(feature = "benchmark")]
// #[bench]
// fn bench_san_cmbr(b: &mut Bencher) {
// let file_path = get_project_root().unwrap().join("data/twic1544.pgn");
// let file = File::open(file_path.clone());

// if file.is_err() {
// panic!(
// "[ERROR] {}. File path: {:?}",
// file.err().unwrap(),
// file_path
// );
// }

// // SAFE: Safe
// let file = unsafe { file.unwrap_unchecked() };
// let mmap = unsafe { Mmap::map(&file) };

// if mmap.is_err() {
// panic!("[ERROR] {}", mmap.err().unwrap());
// }

// let mut mmap = mmap.unwrap();
// let ast = pgn::parse_pgn(&mut mmap);
// let mut convertor = SanToCmbrMvConvertor::new(/* 128MB */ 128 * 1024 * 1024);

// b.iter(|| {
// 'game: for game in &ast {
// // This clone is fucking it up
// for (_, variation) in (&game).0 .1.clone() {
// let mut board = Chess::new();

// for token in variation.0 {
// match token {
// PgnToken::Token(t) => match t {
// Token::Move(san) => {
// let cmbr = convertor.san_to_cmbr(&mut board, san);

// if cmbr.is_err() {
// continue 'game;
// }
// }
// _ => {}
// },

// _ => {}
// }
// }
// }
// }
// });
// }
}
4 changes: 2 additions & 2 deletions libcmbr/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![allow(non_upper_case_globals)]
#![feature(test, map_try_insert)]
#![allow(non_upper_case_globals, clippy::needless_return)]
#![feature(test, map_try_insert, stmt_expr_attributes)]

use cfg_if::cfg_if;

Expand Down
20 changes: 12 additions & 8 deletions libcmbr/src/pgn/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct PgnVariation<'a>(pub Vec<PgnToken<'a>>);
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PgnGame<'a> {
pub global_tokens: Vec<Token<'a>>,
pub variations: LiteMap<u16, PgnVariation<'a>>
pub variations: LiteMap<u16, PgnVariation<'a>>,
}

/// Builds an ast (represented as `a Vec<PgnGame>`) from the inputted Token list
Expand All @@ -44,7 +44,7 @@ pub fn build_pgn_ast<'a>(tokens: &mut VecDeque<Token<'a>>) -> Vec<PgnGame<'a>> {
value.variations.insert(0, PgnVariation::default());
}

while tokens.len() != 0 {
while !tokens.is_empty() {
next_token(
tokens,
&mut tree,
Expand All @@ -64,7 +64,7 @@ macro_rules! push_token {
$tree
.get_mut($game_number as usize)
.unwrap()
.variations
.variations
.get_mut($variation_depth)
.unwrap()
.0
Expand Down Expand Up @@ -95,14 +95,14 @@ fn next_token<'a>(
Token::TagSymbol(_) | Token::TagString(_) => tree
.get_mut(*game_number as usize)
.unwrap()
.global_tokens
.global_tokens
.push(token),
Token::NullMove(_) => {}
Token::EscapeComment(_) => { /* NOTE: IDK what to do with this */ }
Token::Result(_) => {
tree.get_mut(*game_number as usize)
.unwrap()
.global_tokens
.global_tokens
.push(token);

*game_number += 1;
Expand All @@ -114,7 +114,9 @@ fn next_token<'a>(
let value = &mut tree.get_unchecked_mut(*game_number as usize);

value.global_tokens = Vec::new();
value.variations.insert(variation_depth, PgnVariation::default());
value
.variations
.insert(variation_depth, PgnVariation::default());
}
}
Token::StartVariation(_) => {
Expand All @@ -132,7 +134,9 @@ fn next_token<'a>(
// SAFE: Safe
unsafe {
let value = &mut tree.get_unchecked_mut(*game_number as usize);
value.variations.insert(new_variation_depth, PgnVariation::default());
value
.variations
.insert(new_variation_depth, PgnVariation::default());
}

next_token(
Expand All @@ -148,7 +152,7 @@ fn next_token<'a>(
}
}

if tokens.len() != 0 {
if !tokens.is_empty() {
next_token(
tokens,
tree,
Expand Down
Loading

0 comments on commit a3b525a

Please sign in to comment.