diff --git a/gen/Cargo.toml b/gen/Cargo.toml index e24be0f4..e26383da 100644 --- a/gen/Cargo.toml +++ b/gen/Cargo.toml @@ -13,3 +13,4 @@ rustc-hash = "2.0.0" syn = "2.0.67" quote = "1.0" proc-macro2 = "1.0.86" +either = "1.13.0" diff --git a/gen/src/lib.rs b/gen/src/lib.rs index ebddf03f..472d1fc6 100644 --- a/gen/src/lib.rs +++ b/gen/src/lib.rs @@ -1,3 +1,4 @@ +use either::Either; use gen_impl::parser; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; @@ -28,7 +29,7 @@ static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; struct OpcodeDefinitions { definitions: Vec, - block_selection: Vec, usize)>>, + block_selection: Vec>, usize)>>, } impl OpcodeDefinitions { @@ -51,33 +52,51 @@ impl OpcodeDefinitions { _ => {} } 'check_definitions: for i in unselected.iter().copied() { - // Attempt every modifier - 'check_candidates: for candidate in definitions[i] + let mut candidates = definitions[i] .unordered_modifiers .iter() .chain(definitions[i].ordered_modifiers.iter()) - { - let candidate = if let DotModifierRef::Direct { - optional: false, - value, - .. - } = candidate - { - value - } else { - continue; - }; + .filter(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, .. + } + | DotModifierRef::Indirect { + optional: false, .. + } => true, + _ => false, + }) + .collect::>(); + candidates.sort_by_key(|modifier| match modifier { + DotModifierRef::Direct { .. } => 1, + DotModifierRef::Indirect { value, .. } => value.alternatives.len(), + }); + // Attempt every modifier + 'check_candidates: for candidate_modifier in candidates { // check all other unselected patterns for j in unselected.iter().copied() { if i == j { continue; } - if definitions[j].possible_modifiers.contains(candidate) { - continue 'check_candidates; + let candidate_set = match candidate_modifier { + DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)), + DotModifierRef::Indirect { value, .. } => { + Either::Right(value.alternatives.iter()) + } + }; + for candidate_value in candidate_set { + if definitions[j].possible_modifiers.contains(candidate_value) { + continue 'check_candidates; + } } } // it's unique - selections[i] = Some((Some(candidate), generation)); + let candidate_vec = match candidate_modifier { + DotModifierRef::Direct { value, .. } => vec![value.clone()], + DotModifierRef::Indirect { value, .. } => { + value.alternatives.iter().cloned().collect::>() + } + }; + selections[i] = Some((Some(candidate_vec), generation)); selected_something = true; continue 'check_definitions; } @@ -96,9 +115,9 @@ impl OpcodeDefinitions { let mut current_generation_definitions = Vec::new(); for (idx, selection) in selections.iter_mut().enumerate() { match selection { - Some((modifier, generation)) => { + Some((modifier_set, generation)) => { if *generation == current_generation { - current_generation_definitions.push((modifier.cloned(), idx)); + current_generation_definitions.push((modifier_set.clone(), idx)); *selection = None; } } @@ -181,6 +200,8 @@ impl SingleOpcodeDefinition { let name = &arg.ident; let arg_type = if arg.unified { quote! { (ParsedOperandStr<'input>, bool) } + } else if arg.can_be_negated { + quote! { (bool, ParsedOperandStr<'input>) } } else { quote! { ParsedOperandStr<'input> } }; @@ -222,9 +243,6 @@ impl SingleOpcodeDefinition { unnamed_rules = FxHashMap::default(); } let mut possible_modifiers = FxHashSet::default(); - for (_, options) in named_rules.iter() { - possible_modifiers.extend(options.alternatives.iter().cloned()); - } let parser::OpcodeDecl(instruction, arguments) = opcode_decl; let mut unordered_modifiers = instruction .modifiers @@ -232,6 +250,7 @@ impl SingleOpcodeDefinition { .map(|parser::MaybeDotModifier { optional, modifier }| { match named_rules.get(&modifier) { Some(alts) => { + possible_modifiers.extend(alts.alternatives.iter().cloned()); if alts.alternatives.len() == 1 && alts.type_.is_none() { DotModifierRef::Direct { optional, @@ -437,11 +456,10 @@ fn emit_parse_function( for (selection_key, selected_definition) in selection_layer { let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); match selection_key { - Some(selection_key) => { - let selection_key = - selection_key.dot_capitalized(); + Some(selection_keys) => { + let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized()); quote! { - else if modifiers.contains(& #type_name :: #selection_key) { + else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* { #def_parser } } @@ -715,7 +733,7 @@ fn emit_definition_parser( | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), }); let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { - let comma = if idx == 0 { + let comma = if idx == 0 || arg.pre_pipe { quote! { empty } } else { quote! { any.verify(|t| *t == #token_type::Comma).void() } @@ -774,10 +792,17 @@ fn emit_definition_parser( (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) }; let arg_name = &arg.ident; + if arg.unified && arg.can_be_negated { + panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`") + } let inner_parser = if arg.unified { quote! { #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) } + } else if arg.can_be_negated { + quote! { + #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name)) + } } else { quote! { #pattern.map(|(_, _, _, _, name, _, _)| name) diff --git a/gen_impl/src/lib.rs b/gen_impl/src/lib.rs index 7160603d..57660fb0 100644 --- a/gen_impl/src/lib.rs +++ b/gen_impl/src/lib.rs @@ -70,7 +70,7 @@ impl GenerateInstructionType { let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To> }, + quote! { <#type_parameters, To: Operand> }, quote! { <#short_parameters, To> }, quote! { #type_name }, ) @@ -514,19 +514,29 @@ impl ArgumentField { .unwrap_or_else(|| quote! { StateSpace::Reg }); let is_dst = self.is_dst; let name = &self.name; - let arguments_name = if is_mut { - quote! { - &mut arguments.#name - } + let (operand_fn, arguments_name) = if is_mut { + ( + quote! { + VisitOperand::visit_mut + }, + quote! { + &mut arguments.#name + }, + ) } else { - quote! { - & arguments.#name - } + ( + quote! { + VisitOperand::visit + }, + quote! { + & arguments.#name + }, + ) }; quote! {{ let type_ = #type_; let space = #space; - visitor.visit(#arguments_name, &type_, space, #is_dst); + #operand_fn(#arguments_name, |x| visitor.visit(x, &type_, space, #is_dst)); }} } @@ -548,7 +558,7 @@ impl ArgumentField { let #name = { let type_ = #type_; let space = #space; - visitor.visit(arguments.#name, &type_, space, #is_dst) + MapOperand::map(arguments.#name, |x| visitor.visit(x, &type_, space, #is_dst)) }; } } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 714c9b38..e456e03a 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1,6 +1,5 @@ -use std::intrinsics::unreachable; - -use super::{MemScope, ScalarType, StateSpace, VectorPrefix}; +use super::{MemScope, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix}; +use crate::{PtxError, PtxParserState}; use bitflags::bitflags; pub enum Statement { @@ -11,7 +10,7 @@ pub enum Statement { } gen::generate_instruction_type!( - pub enum Instruction { + pub enum Instruction { Mov { type: { &data.typ }, data: MovDetails, @@ -63,6 +62,52 @@ gen::generate_instruction_type!( src2: T, } }, + Setp { + data: SetpData, + arguments: { + dst1: { + repr: T, + type: ScalarType::Pred.into() + }, + dst2: { + repr: Option, + type: ScalarType::Pred.into() + }, + src1: { + repr: T, + type: data.type_.into(), + }, + src2: { + repr: T, + type: data.type_.into(), + } + } + }, + SetpBool { + data: SetpBoolData, + arguments: { + dst1: { + repr: T, + type: ScalarType::Pred.into() + }, + dst2: { + repr: Option, + type: ScalarType::Pred.into() + }, + src1: { + repr: T, + type: data.base.type_.into(), + }, + src2: { + repr: T, + type: data.base.type_.into(), + }, + src3: { + repr: T, + type: ScalarType::Pred.into() + } + } + }, Ret { data: RetData }, @@ -70,6 +115,66 @@ gen::generate_instruction_type!( } ); +pub trait Visitor { + fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMut { + fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); +} + +pub trait VisitorMap { + fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; +} + +trait VisitOperand { + type Operand; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)); + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)); +} + +impl VisitOperand for T { + type Operand = Self; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + fn_(self) + } + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + fn_(self) + } +} + +impl VisitOperand for Option { + type Operand = T; + fn visit(&self, fn_: impl FnOnce(&Self::Operand)) { + self.as_ref().map(fn_); + } + fn visit_mut(&mut self, fn_: impl FnOnce(&mut Self::Operand)) { + self.as_mut().map(fn_); + } +} + +trait MapOperand: Sized { + type Input; + type Output; + fn map(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output; +} + +impl MapOperand for T { + type Input = Self; + type Output = U; + fn map(self, fn_: impl FnOnce(T) -> U) -> U { + fn_(self) + } +} + +impl MapOperand for Option { + type Input = T; + type Output = Option; + fn map(self, fn_: impl FnOnce(T) -> U) -> Option { + self.map(|x| fn_(x)) + } +} + pub struct MultiVariable { pub var: Variable, pub count: Option, @@ -89,18 +194,6 @@ pub struct PredAt { pub label: ID, } -pub trait Visitor { - fn visit(&mut self, args: &T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_: &Type, space: StateSpace, is_dst: bool); -} - -pub trait VisitorMap { - fn visit(&mut self, args: From, type_: &Type, space: StateSpace, is_dst: bool) -> To; -} - #[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { // .param.b32 foo; @@ -121,6 +214,43 @@ impl Type { } } +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U16x2 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S16x2 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::B128 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::BF16 => ScalarKind::Float, + ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Pred, +} impl From for Type { fn from(value: ScalarType) -> Self { Type::Scalar(value) @@ -347,3 +477,135 @@ pub enum MulIntControl { High, Wide, } + +pub struct SetpData { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub cmp_op: SetpCompareOp, +} + +impl SetpData { + pub(crate) fn try_parse( + errors: &mut PtxParserState, + cmp_op: super::RawSetpCompareOp, + ftz: bool, + type_: ScalarType, + ) -> Self { + let flush_to_zero = match (ftz, type_) { + (_, ScalarType::F32) => Some(ftz), + _ => { + errors.push(PtxError::NonF32Ftz); + None + } + }; + let type_kind = type_.kind(); + let cmp_op = if type_kind == ScalarKind::Float { + SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) + } else { + match SetpCompareInt::try_from(cmp_op) { + Ok(op) => SetpCompareOp::Integer(op), + Err(err) => { + errors.push(err); + SetpCompareOp::Integer(SetpCompareInt::Eq) + } + } + }; + Self { + type_, + flush_to_zero, + cmp_op, + } + } +} + +pub struct SetpBoolData { + pub base: SetpData, + pub bool_op: SetpBoolPostOp, + pub negate_src3: bool +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareOp { + Integer(SetpCompareInt), + Float(SetpCompareFloat), +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareInt { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareFloat { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, + NanEq, + NanNotEq, + NanLess, + NanLessOrEq, + NanGreater, + NanGreaterOrEq, + IsNotNan, + IsAnyNan, +} + +impl TryFrom for SetpCompareInt { + type Error = PtxError; + + fn try_from(value: RawSetpCompareOp) -> Result { + match value { + RawSetpCompareOp::Eq => Ok(SetpCompareInt::Eq), + RawSetpCompareOp::Ne => Ok(SetpCompareInt::NotEq), + RawSetpCompareOp::Lt => Ok(SetpCompareInt::Less), + RawSetpCompareOp::Le => Ok(SetpCompareInt::LessOrEq), + RawSetpCompareOp::Gt => Ok(SetpCompareInt::Greater), + RawSetpCompareOp::Ge => Ok(SetpCompareInt::GreaterOrEq), + RawSetpCompareOp::Lo => Ok(SetpCompareInt::Less), + RawSetpCompareOp::Ls => Ok(SetpCompareInt::LessOrEq), + RawSetpCompareOp::Hi => Ok(SetpCompareInt::Greater), + RawSetpCompareOp::Hs => Ok(SetpCompareInt::GreaterOrEq), + RawSetpCompareOp::Equ => Err(PtxError::WrongType), + RawSetpCompareOp::Neu => Err(PtxError::WrongType), + RawSetpCompareOp::Ltu => Err(PtxError::WrongType), + RawSetpCompareOp::Leu => Err(PtxError::WrongType), + RawSetpCompareOp::Gtu => Err(PtxError::WrongType), + RawSetpCompareOp::Geu => Err(PtxError::WrongType), + RawSetpCompareOp::Num => Err(PtxError::WrongType), + RawSetpCompareOp::Nan => Err(PtxError::WrongType), + } + } +} + +impl From for SetpCompareFloat { + fn from(value: RawSetpCompareOp) -> Self { + match value { + RawSetpCompareOp::Eq => SetpCompareFloat::Eq, + RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, + RawSetpCompareOp::Lt => SetpCompareFloat::Less, + RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Gt => SetpCompareFloat::Greater, + RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Lo => SetpCompareFloat::Less, + RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Hi => SetpCompareFloat::Greater, + RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, + RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, + RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, + RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, + RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, + RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, + RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, + RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, + } + } +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index b087fb93..785496d1 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -769,6 +769,8 @@ pub enum PtxError { #[error("")] NonF32Ftz, #[error("")] + WrongType, + #[error("")] WrongArrayType, #[error("")] WrongVectorElement, @@ -996,6 +998,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum ScalarType { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum SetpBoolPostOp { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -1424,6 +1429,38 @@ derive_parser!( .rnd: RawFloatRounding = { .rn }; ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp + setp.CmpOp{.ftz}.type p[|q], a, b => { + let data = ast::SetpData::try_parse(state, cmpop, ftz, type_); + ast::Instruction::Setp { + data, + arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b } + } + } + setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, type_); + let data = ast::SetpBoolData { + base, + bool_op: boolop, + negate_src3 + }; + ast::Instruction::SetpBool { + data, + arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c } + } + } + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64, + .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } @@ -1432,8 +1469,6 @@ derive_parser!( ); fn main() { - use winnow::combinator::*; - use winnow::token::*; use winnow::Parser; let lexer = Token::lexer(