From af00bde2615a33eb45d6e2db10d8f94438af434d Mon Sep 17 00:00:00 2001 From: romnnn Date: Thu, 11 Apr 2024 12:59:24 +0200 Subject: [PATCH] ptx: parse all CUDA SDK --- ptx/Cargo.toml | 3 + ptx/README.md | 3 + ptx/bison/Cargo.toml | 1 + ptx/bison/build.rs | 6 +- ptx/bison/src/main.rs | 10 +++ ptx/src/ast.rs | 6 +- ptx/src/lib.rs | 13 +--- ptx/src/main.rs | 46 ++++++++++++ ptx/src/parser.rs | 171 +++++++++++++++++++++++++++++++++++++----- ptx/src/ptx.pest | 169 +++++++++++++++++++++++++++-------------- 10 files changed, 335 insertions(+), 93 deletions(-) create mode 100644 ptx/src/main.rs diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index f51ecab7..56eb68b7 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -26,7 +26,10 @@ itertools = "0" object = "0" bytes = "1" +clap = { version = "4", features = [ "derive" ] } + [dev-dependencies] once_cell = "1" snailquote = "0" diff = { path = "../diff" } +regex = "1" diff --git a/ptx/README.md b/ptx/README.md index 5c2c8764..ad05f980 100644 --- a/ptx/README.md +++ b/ptx/README.md @@ -13,3 +13,6 @@ The provided libraries may in the future be used for ```bash docker run -v "$PWD/kernels/:/out" ptxsamples ``` + +i = 194 +[ptx/src/parser.rs:1909:13] &kernel.path() = "/Users/roman/dev/box/ptx/kernels/cuda_12_3_r123compiler33567101_0_sm50_newdelete.1.sm_50.ptx" diff --git a/ptx/bison/Cargo.toml b/ptx/bison/Cargo.toml index 1641a543..396141d6 100644 --- a/ptx/bison/Cargo.toml +++ b/ptx/bison/Cargo.toml @@ -13,3 +13,4 @@ color-eyre = "0" duct = "0" bindgen = "0" cc = { version = "1", features = [] } +utils = { path = "../../utils" } diff --git a/ptx/bison/build.rs b/ptx/bison/build.rs index 3cede5fc..1e8e2ccd 100644 --- a/ptx/bison/build.rs +++ b/ptx/bison/build.rs @@ -171,7 +171,7 @@ fn build_ptx_parser() -> eyre::Result<()> { // generated_ptx_lexer, // generated_ptx_parser, // ]; - let sources = [generated_files.clone(), vec![ + let sources = vec![ source_dir.join("util.cc"), source_dir.join("gpgpu.cc"), 
source_dir.join("gpgpu_sim.cc"), @@ -189,7 +189,9 @@ fn build_ptx_parser() -> eyre::Result<()> { source_dir.join("operand_info.cc"), source_dir.join("symbol.cc"), source_dir.join("lib.cc"), - ]].concat(); + ]; + // let sources = utils::fs::multi_glob([source_dir.join("*.cc").to_string_lossy().to_string()]).collect::>()?; + let sources = [generated_files.clone(), sources].concat(); // let sources = vec![ // source_dir.join("memory_space.cc"), // ]; diff --git a/ptx/bison/src/main.rs b/ptx/bison/src/main.rs index 64390448..976c2111 100644 --- a/ptx/bison/src/main.rs +++ b/ptx/bison/src/main.rs @@ -2,6 +2,7 @@ use color_eyre::eyre; use clap::Parser; use std::path::PathBuf; use std::ffi::CString; +use std::time::Instant; #[derive(Parser, Debug, Clone)] pub struct ParsePTXOptions { @@ -25,8 +26,17 @@ fn main() -> eyre::Result<()> { match options.command { Command::ParsePTX(ParsePTXOptions {ptx_path}) => { + let code_size_bytes = std::fs::metadata(&ptx_path)?.len(); let path = CString::new(ptx_path.to_string_lossy().as_bytes())?; + let start = Instant::now(); unsafe { ptxbison::bindings::load_ptx_from_filename(path.as_c_str().as_ptr()) }; + let dur = start.elapsed(); + let dur_millis = dur.as_millis(); + let dur_secs = dur.as_secs_f64(); + let code_size_mib = code_size_bytes as f64 / (1024.0*1024.0); + let mib_per_sec = code_size_mib / dur_secs; + println!("parsing {} took {} ms ({:3.3} MiB/s)", ptx_path.display(), dur_millis, mib_per_sec); + } } diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index cc3feb06..064f44cb 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -103,6 +103,6 @@ pub enum ParseError<'a> { // pub value: f64, // } -#[derive(Debug, FromPest)] -#[pest_ast(rule(Rule::EOI))] -struct EOI; +// #[derive(Debug, FromPest)] +// #[pest_ast(rule(Rule::EOI))] +// struct EOI; diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index afb03be0..2b6dcce6 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -1,16 +1,9 @@ #![allow(dead_code)] -#[macro_use] -extern crate 
pest_derive;
-#[macro_use]
-extern crate pest_ast;
-#[macro_use]
-extern crate pest;
+pub mod ast;
+pub mod parser;
 
-mod ast;
-mod parser;
-
-use crate::parser::Rule;
+use parser::Rule;
 use ast::{ASTNode, FunctionDeclHeader, ParseError};
 use color_eyre::eyre;
 use pest::iterators::Pair;
diff --git a/ptx/src/main.rs b/ptx/src/main.rs
new file mode 100644
index 00000000..90efa15c
--- /dev/null
+++ b/ptx/src/main.rs
@@ -0,0 +1,46 @@
+use color_eyre::eyre;
+use std::path::PathBuf;
+use std::time::Instant;
+
+use ptx::parser::{Rule, Parser as PTXParser};
+use clap::Parser as ClapParser;
+use pest::Parser as PestParser;
+
+#[derive(ClapParser, Debug, Clone)]
+pub struct ParsePTXOptions {
+    pub ptx_path: PathBuf,
+}
+
+#[derive(ClapParser, Debug, Clone)]
+pub enum Command {
+    ParsePTX(ParsePTXOptions),
+}
+
+#[derive(ClapParser, Debug, Clone)]
+pub struct Options {
+    #[clap(subcommand)]
+    pub command: Command,
+}
+
+
+fn main() -> eyre::Result<()> {
+    color_eyre::install()?;
+    let options = Options::parse();
+
+    match options.command {
+        Command::ParsePTX(ParsePTXOptions {ptx_path}) => {
+            let ptx_code = std::fs::read_to_string(&ptx_path)?;
+            let code_size_bytes = ptx_code.len(); // UTF-8 byte length of the input
+            let start = Instant::now();
+            let _parsed = PTXParser::parse(Rule::program, &ptx_code)?; // parse tree is discarded; only the timing is reported
+            let dur = start.elapsed();
+            let dur_millis = dur.as_millis();
+            let dur_secs = dur.as_secs_f64();
+            let code_size_mib = code_size_bytes as f64 / (1024.0*1024.0);
+            let mib_per_sec = code_size_mib / dur_secs;
+            println!("parsing {} took {} ms ({:3.3} MiB/s)", ptx_path.display(), dur_millis, mib_per_sec);
+        }
+    }
+
+    Ok(())
+}
diff --git a/ptx/src/parser.rs b/ptx/src/parser.rs
index 81f9ad21..307d554a 100644
--- a/ptx/src/parser.rs
+++ b/ptx/src/parser.rs
@@ -1,4 +1,4 @@
-#[derive(Parser)]
+#[derive(pest_derive::Parser)]
 #[grammar = "./ptx.pest"]
 pub struct Parser;
 
@@ -1086,10 +1086,7 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn opcode_precendence() -> eyre::Result<()> {
-        crate::tests::init_test();
-        let
opcodes = [ + const ALL_OPCODES: [&str; 151] = [ "abs", "addp", "addc", @@ -1243,7 +1240,11 @@ mod tests { "xor", ]; - for opcode in opcodes { + + #[test] + fn opcode_precendence() -> eyre::Result<()> { + crate::tests::init_test(); + for opcode in ALL_OPCODES { dbg!(&opcode); assert_parses_to_typed( Rule::opcode, @@ -1349,6 +1350,38 @@ param1 Ok(()) } + + #[allow(non_snake_case)] + #[test] + fn parse_variable_decl_global_align_8_u64_underscore_ztv9containeriie6_initializer( + ) -> eyre::Result<()> { + crate::tests::init_test(); + let want = r#" +(variable_decl + (variable_spec (space_spec (addressable_spec: ".global"))) + (variable_spec (align_spec (integer (decimal: "8")))) + (variable_spec (type_spec (scalar_type: ".u64"))) + (identifier_spec + (identifier: "_ZTV9ContainerIiE") + (integer (decimal: "6"))) + (variable_decl_initializer + (operand (literal_operand (integer (decimal: "0")))) + (operand (literal_operand (integer (decimal: "0")))) + (operand (identifier: "_ZN9ContainerIiED1Ev")) + (operand (identifier: "_ZN9ContainerIiED0Ev")) + (operand (literal_operand (integer (decimal: "0")))) + (operand (literal_operand (integer (decimal: "0")))) + ) +) + "#; + assert_parses_to( + Rule::variable_decl, + r#".global .align 8 .u64 _ZTV9ContainerIiE[6] = {0, 0, _ZN9ContainerIiED1Ev, _ZN9ContainerIiED0Ev, 0, 0};"#, + want, + )?; + Ok(()) + } + #[test] fn parse_prototype_decl_prototype_0_callprototype() -> eyre::Result<()> { crate::tests::init_test(); @@ -1498,6 +1531,47 @@ ld.param.b32 %r115, [retval0+0]; Ok(()) } + + #[test] + fn parse_prototype_decl_prototype_15_callprototype() -> eyre::Result<()> { + crate::tests::init_test(); + let want = r#" +(prototype_decl + (identifier: "prototype_15") + (identifier: "_") + (prototype_param + (scalar_type: ".b64") + (identifier_spec (identifier: "_")) + ) + (prototype_param + (align_spec (integer (decimal: "4"))) + (scalar_type: ".b8") + (identifier_spec + (identifier: "_") + (integer (decimal: "16"))) + ) +) + "#; + let code = 
r#"prototype_15 : .callprototype +()_ (.param .b64 _, .param .align 4 .b8 _[16]); + "#; + assert_parses_to( + Rule::prototype_param, + ".param .align 4 .b8 _[16]", + r#"(prototype_param + (align_spec (integer (decimal: "4"))) + (scalar_type: ".b8") + (identifier_spec + (identifier: "_") + (integer (decimal: "16"))) + ) + "#, + )?; + assert_parses_to(Rule::prototype_decl, code, want)?; + Ok(()) + } + + #[test] fn parse_extern_func_param_b32_func_retval0_vprintf() -> eyre::Result<()> { crate::tests::init_test(); @@ -1582,30 +1656,37 @@ ld.param.b32 %r115, [retval0+0]; Ok(()) } + #[test] fn parse_vshr_u32_u32_u32_clamp_add() -> eyre::Result<()> { crate::tests::init_test(); let want = r#" (instruction_statement (instruction (opcode_spec - (opcode: "ld") - (option (addressable_spec: ".global")) - (option (type_spec (scalar_type: ".b32"))) + (opcode: "vshr") + (option (type_spec (scalar_type: ".u32"))) + (option (type_spec (scalar_type: ".u32"))) + (option (type_spec (scalar_type: ".u32"))) + (option: ".clamp") + (option: ".add") ) - (operand (identifier: "r2")) - (operand (memory_operand - (identifier: "array") - (address_expression (identifier: "r1")) - )) + (operand (identifier: "%r952")) + (operand (identifier: "%r1865")) + (operand (identifier: "%r1079")) + (operand (identifier: "%r1865")) ) ) "#; assert_parses_to( Rule::opcode_spec, "vshr.u32.u32.u32.clamp.add", - r#"(memory_operand - (identifier: "array") - (address_expression (identifier: "r1")) + r#"(opcode_spec + (opcode: "vshr") + (option (type_spec (scalar_type: ".u32"))) + (option (type_spec (scalar_type: ".u32"))) + (option (type_spec (scalar_type: ".u32"))) + (option: ".clamp") + (option: ".add") )"#, )?; assert_parses_to( @@ -1631,6 +1712,7 @@ ld.param.b32 %r115, [retval0+0]; Ok(()) } + #[test] fn parse_loc_1_120_13() -> eyre::Result<()> { crate::tests::init_test(); @@ -1815,21 +1897,70 @@ ld.param.b32 %r115, [retval0+0]; } #[test] - fn parse_all_kernels() -> eyre::Result<()> { + fn extract_opcodes() -> 
eyre::Result<()> {
         use std::fs::{read_dir, read_to_string, DirEntry};
         use std::path::PathBuf;
+        use std::collections::HashSet;
+
         crate::tests::init_test();
-        // pest::set_call_limit(std::num::NonZeroUsize::new(10000));
         let kernels_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("kernels");
         dbg!(&kernels_dir);
         let mut kernels = read_dir(&kernels_dir)?
             .into_iter()
             .collect::<Result<Vec<DirEntry>, _>>()?;
         kernels.sort_by_key(|k| k.path());
+
+        let all_opcodes = ALL_OPCODES.join("|");
+        let opcode_regex = regex::Regex::new(&format!(r"({})(\.[\w.:]*)", all_opcodes)).unwrap(); // group 1: opcode, group 2: its dotted option suffix
+
+        // atom.add.release.gpu.u32 %r57,[%rd10],%r58;
+        let mut all_options = HashSet::new();
         for kernel in kernels {
             dbg!(&kernel.path());
             let ptx_code = read_to_string(kernel.path())?;
-            let parsed = PTXParser::parse(Rule::program, &ptx_code)?;
+            let captures = opcode_regex.captures_iter(&ptx_code);
+            for m in captures {
+                let options = m[2].split(".").filter(|o| !o.is_empty()).map(ToString::to_string);
+                all_options.extend(options);
+            }
+        }
+
+        let mut all_options: Vec<_> = all_options.into_iter().collect();
+        all_options.sort();
+        dbg!(&all_options);
+        Ok(())
+    }
+
+    #[test]
+    fn all_kernels() -> eyre::Result<()> {
+        use std::fs::{read_dir, read_to_string, DirEntry};
+        use std::path::PathBuf;
+        use std::time::Instant;
+        crate::tests::init_test();
+        // pest::set_call_limit(std::num::NonZeroUsize::new(10000));
+        let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+        let kernels_dir = manifest_dir.join("kernels");
+        dbg!(&kernels_dir);
+        let mut kernels = read_dir(&kernels_dir)?
+            .into_iter()
+            .collect::<Result<Vec<DirEntry>, _>>()?;
+        kernels.sort_by_key(|k| k.path());
+
+        let skip = std::env::var("SKIP").ok().map(|s| s.parse::<usize>()).transpose()?.unwrap_or(0); // optional SKIP env var: number of leading kernels to skip
+
+        let kernels_iter = kernels.iter().enumerate().skip(skip);
+
+        for (i, kernel) in kernels_iter {
+            let ptx_code = read_to_string(kernel.path())?;
+            let code_size_bytes = ptx_code.len(); // UTF-8 byte length of the input
+            let start = Instant::now();
+            let _parsed = PTXParser::parse(Rule::program, &ptx_code)?;
+            let dur = start.elapsed();
+            let dur_millis = dur.as_millis();
+            let dur_secs = dur.as_secs_f64();
+            let code_size_mib = code_size_bytes as f64 / (1024.0*1024.0);
+            let mib_per_sec = code_size_mib / dur_secs;
+            println!("[{:>4}] parsing {} took {} ms ({:3.3} MiB/s)", i, &kernel.path().display(), dur_millis, mib_per_sec);
         }
         Ok(())
     }
diff --git a/ptx/src/ptx.pest b/ptx/src/ptx.pest
index b8521c7e..6bba40e4 100644
--- a/ptx/src/ptx.pest
+++ b/ptx/src/ptx.pest
@@ -22,9 +22,13 @@ variable_decl = {
   | variable_spec_list ~ identifier_spec ~ "=" ~ variable_decl_initializer
   | variable_spec_list ~ identifier_list
 }
-variable_decl_initializer = { initializer_list | literal_operand | identifier_spec }
-initializer_list = _{
-  "{" ~ initializer_list ~ "}" | "{" ~ literal_list ~ "}"
+variable_decl_initializer = {
+  initializer_list | operand_list
+  // initializer_list | literal_operand | identifier_spec
+}
+initializer_list = _{
+  "{" ~ operand_list ~ "}"
+  // "{" ~ initializer_list ~ "}" | "{" ~ literal_list ~ "}"
 }
 literal_list = _{
   literal_operand ~ ("," ~ literal_operand)*
@@ -376,11 +380,11 @@ loc_function_name_attr = { "function_name" ~ loc_function_name_label ~ ("+" ~ in
 loc_inlined_at_attr = { "inlined_at" ~ integer ~ integer ~ integer }
 
 ptr_spec = {
-  ".ptr" ~ ptr_space_spec ~ ptr_align_spec
-  | ".ptr" ~ ptr_align_spec
+  ".ptr" ~ ptr_space_spec ~ align_spec
+  | ".ptr" ~ align_spec
 }
 ptr_space_spec = { ".global" | ".local" | ".shared" | ".const" }
-ptr_align_spec = { ".align" ~ integer }
+// ptr_align_spec = { ".align" ~
integer } align_spec = { ".align" ~ integer } variable_spec = { @@ -400,74 +404,119 @@ variable_spec_list = _{ // specific case) precedes ".lo" option = { // level must come before compare_spec (".le") - ".level" - | type_spec - | addressable_spec - | compare_spec - | wmma_spec - | rounding_mode - | prmt_spec - // | atomic_operation_spec - | ".sync" - | ".sys" - | ".shiftamt" - | ".sat" + // | type_spec + // | addressable_spec + // | compare_spec + wmma_spec + // | rounding_mode + // | prmt_spec + | cache_level ~ "::" ~ cache_eviction_priority + ~ cache_level ~ "::" ~ "cache_hint" + ~ cache_level ~ "::" ~ cache_prefetch_size + | cache_level ~ "::" ~ "cache_hint" + ~ cache_level ~ "::" ~ cache_prefetch_size + | cache_level ~ "::" ~ cache_prefetch_size | ".arrive" + | ".approx" + | ".async" + | ".acquire" + | ".aligned" | ".and" | ".all" | ".any" | ".abs" | ".add" - | ".approx" + | ".a2d" | ".ballot" | ".bfly" + | ".bf16x2" | ".bf16" + | ".b8" | ".b16" | ".b32" | ".b64" | ".bb64" | ".b128" | ".bb128" + | ".b4e" | ".b" - | ".global" - | ".gl" - | ".red" - | ".release" - | ".popc" - | ".1d" - | ".2d" - | ".3d" - | ".ftz" - | ".full" - | ".exch" - | ".exit" - | ".extp" - | ".to" - | ".trap" - | ".half" - | ".clamp" + | ".commit_group" + | ".const" + | ".clamp" + | ".cube" | ".cas" | ".cta" + | ".col" | ".ca" + | ".cu" | ".cg" | ".cs" | ".cv" + | ".down" + | ".dec" + | ".e4m3" | ".e5m2" + | ".exch" + | ".exit" + | ".extp" + | ".ecl" | ".ecr" + | ".equ" | ".eq" + | ".full" + | ".ftz" + // floating point + | ".f16x2" | ".f16" | ".f32" | ".f64" | ".ff64" + | ".f4e" + | ".gtu" | ".geu" | ".gt" | ".ge" + | ".gpu" + | ".global" | ".gl" + | ".half" + | ".hi" | ".hs" + | ".inc" + | ".idx" + | ".local" + | ".level" + | ".ltu" | ".leu" | ".lt" | ".le" | ".lo" | ".ls" | ".lu" - | cache_level ~ "::" ~ cache_eviction_priority - ~ cache_level ~ "::" ~ "cache_hint" - ~ cache_level ~ "::" ~ cache_prefetch_size - | cache_level ~ "::" ~ "cache_hint" - ~ cache_level ~ "::" ~ 
cache_prefetch_size - | cache_level ~ "::" ~ cache_prefetch_size - | ".wide" - | ".wb" - | ".wt" - | ".nc" - | ".neg" + | ".max" | ".min" + | ".m16n8k16" | ".m16n8k32" | ".m16n8k8" | ".m8n8k4" | ".m8n8" | ".noftz" - | ".max" - | ".min" - | ".uni" - | ".up" + | ".neg" + | ".nc" + | ".neu" | ".num" | ".nan" | ".ne" | ".or" - | ".down" - | ".dec" - | ".inc" - | ".idx" + | ".param" + | ".popc" + | ".pred" + | ".release" + | ".relaxed" + | ".red" + | ".row" + | ".rc8" | ".rc16" + | ".rni" | ".rzi" | ".rmi" | ".rpi" + | ".rn" | ".rz" | ".rm" | ".rp" + | ".sampleref" + | ".shiftamt" + | ".surfref" | ".surf" + | ".satfinite" | ".sat" + | ".sstarr" + | ".shared" + | ".sync" + | ".sys" + // signed integer + | ".s8" | ".s16" | ".s32" | ".s64" + | ".texref" | ".tex" + | ".tf32" + | ".trans" + | ".trap" + | ".to" + | ".uni" + | ".up" + // unsigned integer + | ".u8" | ".u16" | ".u32" | ".u64" + | ".v2" | ".v3" | ".v4" + | ".wait_group" + | ".wait_all" + | ".wide" + | ".wb" + | ".wt" | ".xor" + | ".x4" + | ".1d" + | ".2d" + | ".3d" + } cache_level = { ".L1" | ".L2" } cache_eviction_priority = { @@ -506,8 +555,9 @@ compare_spec = { } prmt_spec = { ".f4e" | ".b4e" | ".rc8" | ".rc16" | ".ecl" | ".ecr" } wmma_spec = { - wmma_directive ~ layout ~ configuration - | wmma_directive ~ layout ~ layout ~ configuration + wmma_directive ~ layout{1,2} ~ configuration + // wmma_directive ~ layout ~ configuration + // | wmma_directive ~ layout ~ layout ~ configuration } wmma_directive = { ".a.sync" | ".b.sync" | ".c.sync" | ".d.sync" | ".mma.sync" } layout = { ".row" | ".col" } @@ -531,7 +581,10 @@ prototype_call_function = @{ operand } // } // prototype_param = { ".param" ~ (".b32" | ".b64") ~ identifier } -prototype_param = { ".param" ~ scalar_type ~ identifier } +prototype_param = { + ".param" ~ align_spec? 
~ scalar_type ~ identifier_spec + // | ".param" ~ scalar_type ~ identifier_spec +} prototype_param_list = _{ prototype_param ~ ("," ~ prototype_param)* // prototype_param ~ "," ~ prototype_param_list | prototype_param