diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 4251d97d..944341c1 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -258,6 +258,30 @@ gen::generate_instruction_type!( src2: T, } }, + Rcp { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Sqrt { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Rsqrt { + type: { Type::from(data.type_) }, + data: RsqrtData, + arguments: { + dst: T, + src: T, + } + }, Trap { } } ); @@ -1117,3 +1141,29 @@ pub struct MinMaxFloat { pub nan: bool, pub type_: ScalarType, } + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + Full, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone)] +pub struct RcpData { + pub kind: RcpKind, + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum RcpKind { + Approx, + Full(RoundingMode), +} + +#[derive(Copy, Clone)] +pub struct RsqrtData { + pub flush_to_zero: Option, + pub type_: ScalarType, +} diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index 71d8dced..159b918e 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -2244,6 +2244,107 @@ derive_parser!( } ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64 + rcp.approx{.ftz}.type d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_ + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd{.ftz}.f32 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd.f64 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f64 }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt + sqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd.f64 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Full(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 + rsqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.ftz.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::RsqrtData { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } }