Skip to content

Commit

Permalink
Add abs, mad
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 21, 2024
1 parent 588d66b commit 6cd18bf
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 9 deletions.
4 changes: 2 additions & 2 deletions gen_impl/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub struct OpcodeDecl(pub Instruction, pub Arguments);

impl OpcodeDecl {
fn peek(input: syn::parse::ParseStream) -> bool {
Instruction::peek(input)
Instruction::peek(input) && !input.peek2(Token![=])
}
}

Expand Down Expand Up @@ -106,7 +106,7 @@ impl Parse for CodeBlock {
} else {
return Err(lookahead.error());
};
Ok(Self{special, code})
Ok(Self { special, code })
}
}

Expand Down
69 changes: 65 additions & 4 deletions ptx_parser/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,27 @@ gen::generate_instruction_type!(
src: T,
}
},
Abs {
data: AbsDetails,
type: { Type::Scalar(data.type_) },
arguments<T>: {
dst: T,
src: T,
}
},
Mad {
type: { Type::from(data.type_()) },
data: MadDetails,
arguments<T>: {
dst: {
repr: T,
type: { Type::from(data.dst_type()) },
},
src1: T,
src2: T,
src3: T,
}
},
Trap { }
}
);
Expand Down Expand Up @@ -588,16 +609,14 @@ pub enum MulDetails {
}

impl MulDetails {
#[allow(unused)] // Used by generated code
fn type_(&self) -> ScalarType {
pub fn type_(&self) -> ScalarType {
match self {
MulDetails::Integer { type_, .. } => *type_,
MulDetails::Float(arith) => arith.type_,
}
}

#[allow(unused)] // Used by generated code
fn dst_type(&self) -> ScalarType {
pub fn dst_type(&self) -> ScalarType {
match self {
MulDetails::Integer {
type_,
Expand Down Expand Up @@ -995,3 +1014,45 @@ pub enum CvtaDirection {
GenericToExplicit,
ExplicitToGeneric,
}

/// Modifiers attached to a parsed PTX `abs` instruction.
#[derive(Copy, Clone)]
pub struct AbsDetails {
    /// State of the optional `.ftz` modifier.
    ///
    /// `Some(_)` for the forms whose grammar accepts `.ftz` (`.f32`, `.f16`,
    /// `.f16x2`); `None` for the forms that have no such modifier (integer
    /// types, `.f64`, `.bf16`, `.bf16x2`).
    pub flush_to_zero: Option<bool>,
    /// Scalar type named by the instruction's type suffix.
    pub type_: ScalarType,
}

/// Modifiers attached to a parsed PTX `mad` (multiply-add) instruction.
#[derive(Copy, Clone)]
pub enum MadDetails {
    /// Integer form: `mad.{hi,lo,wide}[.sat].type`.
    Integer {
        /// Which part of the intermediate product is used (`.hi`, `.lo` or `.wide`).
        control: MulIntControl,
        /// Whether `.sat` was specified; the parser only sets this for `mad.hi.sat.s32`.
        saturate: bool,
        /// Scalar type of the source operands.
        type_: ScalarType,
    },
    /// Floating-point form (`.f32` / `.f64`) with its rounding, `.ftz` and `.sat` modifiers.
    Float(ArithFloat),
}

impl MadDetails {
    /// Returns the scalar type of the destination operand.
    ///
    /// For integer `mad.wide` the destination is twice as wide as the sources
    /// (`u16 -> u32`, `s16 -> s32`, `u32 -> u64`, `s32 -> s64`). Every other
    /// form writes a destination of the same type as its sources, i.e.
    /// [`MadDetails::type_`].
    ///
    /// # Panics
    ///
    /// Panics if the instruction is `mad.wide` with a type other than
    /// `U16`, `S16`, `U32` or `S32`. The parser never produces such a value.
    pub fn dst_type(&self) -> ScalarType {
        match self {
            MadDetails::Integer {
                type_,
                control: MulIntControl::Wide,
                ..
            } => match type_ {
                ScalarType::U16 => ScalarType::U32,
                ScalarType::S16 => ScalarType::S32,
                ScalarType::U32 => ScalarType::U64,
                ScalarType::S32 => ScalarType::S64,
                _ => unreachable!("mad.wide only accepts 16-bit and 32-bit integer types"),
            },
            // Spelled out instead of `_` so that adding a variant forces this
            // match to be revisited.
            MadDetails::Integer { .. } | MadDetails::Float(_) => self.type_(),
        }
    }

    /// Returns the scalar type of the source operands.
    ///
    /// Public for parity with `MulDetails::type_`, which exposes the same
    /// accessor next to `dst_type`.
    pub fn type_(&self) -> ScalarType {
        match self {
            MadDetails::Integer { type_, .. } => *type_,
            MadDetails::Float(arith) => arith.type_,
        }
    }
}
173 changes: 170 additions & 3 deletions ptx_parser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,8 @@ derive_parser!(
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
mul.mode.type d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Integer {
Expand All @@ -1476,8 +1478,6 @@ derive_parser!(
.s16, .s32 };
RawMulIntControl = { .wide };


// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
Expand Down Expand Up @@ -1507,7 +1507,6 @@ derive_parser!(
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
Expand Down Expand Up @@ -1706,6 +1705,174 @@ derive_parser!(
.space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} };
.size: ScalarType = { .u32, .u64 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs
abs.type d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f32 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f32
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.f64 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: f64
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f16 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f16
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs{.ftz}.f16x2 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: Some(ftz),
type_: f16x2
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.bf16 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: bf16
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
abs.bf16x2 d, a => {
ast::Instruction::Abs {
data: ast::AbsDetails {
flush_to_zero: None,
type_: bf16x2
},
arguments: ast::AbsArgs {
dst: d, src: a
}
}
}
.type: ScalarType = { .s16, .s32, .s64 };
ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
mad.mode.type d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_,
control: mode.into(),
saturate: false
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s32, .s64 };
.mode: RawMulIntControl = { .hi, .lo };

// The .wide suffix is supported only for 16-bit and 32-bit integer types.
mad.wide.type d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_,
control: wide.into(),
saturate: false
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
.type: ScalarType = { .u16, .u32,
.s16, .s32 };
RawMulIntControl = { .wide };

mad.hi.sat.s32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Integer {
type_: s32,
control: hi.into(),
saturate: true
},
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
RawMulIntControl = { .hi };
ScalarType = { .s32 };

mad{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f32,
rounding: None,
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
mad.rnd{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}
}
mad.rnd.f64 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
flush_to_zero: None,
saturate: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}}
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
Expand Down

0 comments on commit 6cd18bf

Please sign in to comment.