diff --git a/SciLean/Modules/ML/XLA/Concatenate.lean b/SciLean/Modules/ML/XLA/Concatenate.lean
new file mode 100644
index 00000000..15fbc088
--- /dev/null
+++ b/SciLean/Modules/ML/XLA/Concatenate.lean
@@ -0,0 +1,102 @@
+import SciLean.Modules.ML.XLA.TensorIndex
+
+/-!
+
+### concatenate
+
+#### Semantics
+
+Concatenates `inputs` along the `dimension` dimension in the same order as the
+given arguments and produces a `result` tensor. More formally,
+`result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]`, where:
+
+1. `id = d0 + ... + dk-1 + kd`.
+1. `d` is equal to `dimension`, and `d0`, ... are the `d`th dimension sizes
+   of `inputs`.
+
+#### Inputs
+
+| Label | Name        | Type                                                       | Constraints      |
+|-------|-------------|------------------------------------------------------------|------------------|
+| (I1)  | `inputs`    | variadic number of tensors or per-tensor quantized tensors | (C1-C6)          |
+| (I2)  | `dimension` | constant of type `si64`                                    | (C2), (C4), (C6) |
+
+#### Outputs
+
+| Name     | Type                                  | Constraints |
+|----------|---------------------------------------|-------------|
+| `result` | tensor or per-tensor quantized tensor | (C5-C6)     |
+
+#### Constraints
+
+* (C1) `same(element_type(inputs...))`.
+* (C2) `same(shape(inputs...))` except for `dim(inputs..., dimension)`.
+* (C3) `0 < size(inputs)`.
+* (C4) `0 <= dimension < rank(inputs[0])`.
+* (C5) `element_type(result) = element_type(inputs[0])`.
+* (C6) `shape(result) = shape(inputs[0])` except for:
+  * `dim(result, dimension) = dim(inputs[0], dimension) + ...`.
+
+#### Examples
+
+```mlir
+// %input0: [[1, 2], [3, 4], [5, 6]]
+// %input1: [[7, 8]]
+%result = "stablehlo.concatenate"(%input0, %input1) {
+  dimension = 0 : i64
+} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
+// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
+```
+
+[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/concatenate.mlir)
+
+-/
+
+namespace SciLean
+
+namespace Concatenate
+
+structure Args {r k} (inDims : Fin k → Dims r) where
+  dimension : Fin r
+
+def Args.outShape {r k} {inDims : Fin k → Dims r} (args : Args inDims) : Dims r :=
+  .ofFn fun d =>
+    if d = args.dimension then
+      -- along `dimension` the input sizes are summed
+      ∑ i, (inDims i)[d]
+    else if h : 0 < k then
+      -- all inputs agree on the remaining dimensions (C2)
+      (inDims ⟨0, by linarith⟩)[d]
+    else
+      0
+
+def Args.indexMap {r k} {inDims : Fin k → Dims r} (args : Args inDims) :
+    (i : Fin k) × TensorIndex (inDims i)
+    ≃
+    TensorIndex args.outShape := sorry
+
+structure Constraints {r k} {inDims : Fin k → Dims r} (args : Args inDims) (outDims : Dims r) where
+  c1 : True
+  c2 : ∀ d, d ≠ args.dimension → ∀ i j, (inDims i)[d] = (inDims j)[d]
+  c3 : 0 < k
+  c4 : (0:ℕ) ≤ args.dimension.1 ∧ args.dimension.1 < r
+  c5 : True
+  c6 : ∀ d,
+    if d = args.dimension then
+      outDims[d] = ∑ i, (inDims i)[d]
+    else
+      ∀ i, outDims[d] = (inDims i)[d]
+
+end Concatenate
+
+open Concatenate in
+def concatenate {r k} {inDims : Fin k → Dims r} {outDims : Dims r}
+    (inputs : (i : Fin k) → TensorIndex (inDims i) → R)
+    (args : Args inDims)
+    (h : Constraints args outDims)
+    (houtDims : outDims = args.outShape := by infer_var) :
+    TensorIndex outDims → R :=
+  fun i =>
+    -- pull the result index back to a pair of an input id and an index into that input
+    let ⟨k', j⟩ := args.indexMap.symm (houtDims ▸ i)
+    inputs k' j
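+
+-- Sanity check of the index arithmetic (illustrative; `indexMap` is still
+-- `sorry`, so the statement is left as a comment rather than a lemma): in the
+-- mlir example above `k = 2`, `inDims 0 = [3, 2]`, `inDims 1 = [1, 2]` and
+-- `dimension = 0`, so `outShape = [3 + 1, 2] = [4, 2]`, and the formula
+-- `id = d0 + ... + dk-1 + kd` places `inputs[1][0, j]` at `result[3 + 0, j]`,
+-- i.e. `[7, 8]` becomes the last row of `%result`. Once `indexMap` is
+-- implemented, round-tripping should be the defining property:
+--
+--   example (args : Args inDims) (h : Constraints args outDims)
+--       (i : Fin k) (j : TensorIndex (inDims i)) :
+--       concatenate inputs args h rfl (args.indexMap ⟨i, j⟩) = inputs i j := by
+--     simp [concatenate]  -- expected to follow from `Equiv.symm_apply_apply`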
diff --git a/SciLean/Modules/ML/XLA/Convolution.lean b/SciLean/Modules/ML/XLA/Convolution.lean
index 3dbca30a..3a8efd47 100644
--- a/SciLean/Modules/ML/XLA/Convolution.lean
+++ b/SciLean/Modules/ML/XLA/Convolution.lean
@@ -232,8 +232,9 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
 -/
 
+namespace Convolution
 
-structure convolution.ArgData {r} (lhsDims rhsDims : Dims r) where
+structure Args {r} (lhsDims rhsDims : Dims r) where
   window_strides : ArrayN ℤ (r-2)
   padding : ArrayN (ℤ×ℤ) (r-2)
   lhs_dilation : ArrayN ℕ+ (r-2)
@@ -241,89 +242,70 @@ structure convolution.ArgData {r} (lhsDims rhsDims : Dims r) where
   window_reversal : ArrayN Bool r
   input_batch_dimension : Fin r
   input_feature_dimension : Fin r
-  input_spatial_dimensions : ArrayN ℤ (r-2)
+  input_spatial_dimensions : ArrayN (Fin r) (r-2)
   kernel_input_feature_dimension : Fin r
   kernel_output_feature_dimension : Fin r
-  kernel_spatial_dimensions : ArrayN ℤ (r-2)
+  kernel_spatial_dimensions : ArrayN (Fin r) (r-2)
   output_batch_dimension : Fin r
   output_feature_dimension : Fin r
-  output_spatial_dimensions : ArrayN ℤ (r-2)
+  output_spatial_dimensions : ArrayN (Fin r) (r-2)
   feature_group_count : ℕ+
   batch_group_count : ℕ+
   precision_config : True
 
+namespace Args
 
-def convolution.ArgData.lhsSpatialShape {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry
-
-def convolution.ArgData.rhsSpatialShape {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry
-
-def convolution.ArgData.outSpatialShape {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry
-
-def convolution.ArgData.lowPadding {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.1)
-
-def convolution.ArgData.highPadding {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.2)
-
-def convolution.ArgData.lhsShapeMap {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) :
-    α × ArrayN α (r - 2) × α
-    ≃
-    ArrayN α r := sorry
-
-def convolution.ArgData.outShapeMap {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) :
-    α × ArrayN α (r - 2) × α
-    ≃
-    ArrayN α r := sorry
-
-def convolution.ArgData.rhsShapeMap {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) :
-    α × ArrayN α (r - 2) × α
-    ≃
-    ArrayN α r := sorry
-
-def convolution.ArgData.outShapeMap {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) :
-    ArrayN α r
-    ≃
-    α × ArrayN α (r - 2) × α := sorry
-
-
-def convolution.ArgData.outDims {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) : Dims r :=
-
-  let output_batch_dim_size := lhsDims[args.input_batch_dimension] / args.batch_group_count
-  let output_feature_dim_size := rhsDims[args.kernel_output_feature_dimension]
-
-  let dilated_input_shape := (args.lhsSpatialShape - 1) * args.lhs_dilation + 1
-  let padded_input_shape := args.lowPadding + dilated_input_shape + args.highPadding
-  let dilated_window_shape := (args.rhsSpatialShape - 1) * args.rhs_dilation + 1
-  let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
-  let output_spatial_dims := if is_empty_window then 0 else (padded_input_shape - dilated_window_shape) / args.window_strides + 1
-
-  args.outShapeMap (output_batch_dim_size, output_spatial_dims, output_feature_dim_size)
-
-  -- .ofFn fun result_dim =>
-  --   if result_dim = args.output_batch_dimension then
-  --     lhsDims[args.input_batch_dimension] / args.batch_group_count
-  --   else if result_dim = args.output_feature_dimension then
-  --     rhsDims[args.kernel_output_feature_dimension]
-  --   else
-  --     let spatial_dim : Fin (r-2) := sorry -- result_dim - args.output_feature_dimension - 1
-  --     let lhs_dim := args.input_spatial_dimensions[spatial_dim]
-  --     let rhs_dim := args.kernel_spatial_dimensions[spatial_dim]
-  --     let dilated_input_shape := (args.lhsSpatialShape[spatial_dim] - 1) * args.lhs_dilation[spatial_dim] + 1
-  --     let padded_input_shape := args.lowPadding[spatial_dim] + dilated_input_shape + args.highPadding[spatial_dim]
-  --     let dilated_window_shape := (args.rhsspatialShape[spatial_dim] - 1) * args.rhs_dilation[spatial_dim] + 1
-  --     let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
-  --     if is_empty_window then 0 else (padded_input_shape - dilated_window_shape) / args.window_strides[spatial_dim] + 1
-
-structure convolution.Conditions {r} {lhsDims rhsDims : Dims r}
-    (args : convolution.ArgData lhsDims rhsDims) (outDims : Dims r) where
+  def lhsSpatialShape {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : Dims (r - 2) := sorry
+
+  def rhsSpatialShape {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : Dims (r - 2) := sorry
+
+  def outSpatialShape {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : Dims (r - 2) := sorry
+
+  def lowPadding {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.1)
+
+  def highPadding {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.2)
+
+  def lhsShapeMap {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) :
+      α × ArrayN α (r - 2) × α
+      ≃
+      ArrayN α r := sorry
+
+  def rhsShapeMap {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) :
+      α × ArrayN α (r - 2) × α
+      ≃
+      ArrayN α r := sorry
+
+  def outShapeMap {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) :
+      α × ArrayN α (r - 2) × α
+      ≃
+      ArrayN α r := sorry
+
+  def outDims {r} {lhsDims rhsDims : Dims r}
+      (args : Args lhsDims rhsDims) : Dims r :=
+    let output_batch_dim_size := lhsDims[args.input_batch_dimension] / args.batch_group_count
+    let output_feature_dim_size := rhsDims[args.kernel_output_feature_dimension]
+    let dilated_input_shape := (args.lhsSpatialShape - 1) * args.lhs_dilation + 1
+    let padded_input_shape := args.lowPadding + dilated_input_shape + args.highPadding
+    let dilated_window_shape := (args.rhsSpatialShape - 1) * args.rhs_dilation + 1
+    let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
+    let output_spatial_dims :=
+      if is_empty_window then 0
+      else (padded_input_shape - dilated_window_shape) / args.window_strides + 1
+    args.outShapeMap (output_batch_dim_size, output_spatial_dims, output_feature_dim_size)
+
+end Args
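+
+-- Example of the dimension-number bookkeeping above (illustrative; the
+-- concrete numbers are not from the spec text): for an NHWC input with an
+-- HWIO kernel and `r = 4` one takes
+--   input_batch_dimension = 0, input_feature_dimension = 3,
+--   input_spatial_dimensions = [1, 2],
+--   kernel_spatial_dimensions = [0, 1],
+--   kernel_input_feature_dimension = 2, kernel_output_feature_dimension = 3,
+-- so `lhsShapeMap` packs `(batch, [h, w], feature)` into `[batch, h, w, feature]`.
+-- With lhs shape [1, 28, 28, 3], rhs shape [3, 3, 3, 8], unit strides and
+-- dilations, zero padding and both group counts equal to 1, `outDims` computes
+-- per spatial dimension
+--   dilated_input  = (28 - 1) * 1 + 1 = 28
+--   padded_input   = 0 + 28 + 0 = 28
+--   dilated_window = (3 - 1) * 1 + 1 = 3
+--   output_spatial = (28 - 3) / 1 + 1 = 26,
+-- giving `outDims = [1 / 1, 26, 26, 8] = [1, 26, 26, 8]`.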
+
+structure Conditions {r} {lhsDims rhsDims : Dims r}
+    (args : Args lhsDims rhsDims) (outDims : Dims r) where
   c1 : lhsDims = rhsDims
   c2 : args.window_strides.size = r - 2
   c3 : 0 < args.window_strides
@@ -336,18 +318,25 @@ structure convolution.Conditions {r} {lhsDims rhsDims : Dims r}
   c10 : lhsDims[args.input_batch_dimension] % args.batch_group_count = 0
   c11 : lhsDims[args.input_feature_dimension] % args.feature_group_count = 0
   c12 : args.input_spatial_dimensions.size = r - 2
-  c13 : ∀ d, 0 ≤ args.input_spatial_dimensions[d] ∧ args.input_spatial_dimensions[d] < r
+  c13 : ∀ d, (0:ℕ) ≤ args.input_spatial_dimensions[d] ∧ args.input_spatial_dimensions[d] < r
   c14 : rhsDims[args.kernel_input_feature_dimension] = lhsDims[args.input_feature_dimension] / args.feature_group_count
   c15 : rhsDims[args.kernel_output_feature_dimension] % args.batch_group_count = 0
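+  -- Illustration of the grouping constraints (numbers are hypothetical): with
+  -- `feature_group_count = 2`, an lhs feature size of 6 satisfies `c11` since
+  -- 6 % 2 = 0, and `c14` then forces the kernel input feature size to be
+  -- 6 / 2 = 3; `batch_group_count` must likewise divide both the lhs batch
+  -- size (`c10`) and the kernel output feature size (`c15`).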
-def convolution {r} {lhsDims rhsDims outDims : Dims r}
+end Convolution
+
+variable {R} [RealScalar R]
+
+open Convolution in
+def convolutionCore {r} {lhsDims rhsDims outDims : Dims r}
     (lhs : TensorIndex lhsDims → R) (rhs : TensorIndex rhsDims → R)
-    (args : convolution.ArgData lhsDims rhsDims) :
+    (args : Args lhsDims rhsDims)
+    (h : Conditions args outDims) :
     TensorIndex outDims → R :=
   fun i =>
-    let output_spatial_index : ArrayN ℤ (r-2) := sorry -- get the correct parts of `i`
+    let (_,output_spatial_index,_) := args.outShapeMap.symm i.1 -- batch/spatial/feature parts of `i`
     let lhsWindowShape := args.lhsShapeMap
      (lhsDims[args.input_batch_dimension],
@@ -359,30 +348,65 @@ def convolution {r} {lhsDims rhsDims outDims : Dims r}
     let lhs_base_dilation := args.lhsShapeMap (1,args.lhs_dilation,1)
     let lhs_window_dilations := args.lhsShapeMap (1,args.rhs_dilation,1)
 
+    let padded_lhs := pad lhs 0 lhs_padding_low lhs_padding_high lhs_base_dilation.toNat
+      -- there is some issue with elaboration and we have to specify these arguments explicitly
+      (outDims := pad.outDims lhsDims lhs_padding_low lhs_padding_high lhs_base_dilation.toNat) (by infer_var)
+
+    let lhs_window_start : ArrayN ℤ r := args.lhsShapeMap (0,output_spatial_index,0)
+    let lhs_window := slice padded_lhs
+      { start_indices := lhs_window_start
+        limit_indices := (lhs_window_start + lhsWindowShape)
+        strides := lhs_window_dilations
+        c1 := sorry, c2 := sorry, c3 := sorry, c4 := sorry, c5 := sorry }
+      -- there is some issue with elaboration and we have to specify these arguments explicitly
+      (outDims := slice.outDims lhs_window_start (lhs_window_start + lhsWindowShape) lhs_window_dilations) (by infer_var)
-
-    let padded_lhs := pad lhs 0 lhs_padding_low lhs_padding_high lhs_base_dilation.toNat (outDims := pad.outDims lhsDims lhs_padding_low lhs_padding_high lhs_base_dilation.toNat) (by infer_var)
-    let lhs_window_start : ArrayN ℤ r := args.lhsShapeMap 0 output_spatial_index 0
-    let lhs_window := slice padded_lhs lhs_window_start (lhs_window_start + lhs_window_dimensions) lhs_window_dilations
-      (outDims := slice.outDims lhs_window_start (lhs_window_start + lhs_window_dimensions) lhs_window_dilations) (by infer_var)
 
     let dot_product : R := dot_general lhs_window rhs
      {lhs_batching_dimensions := #[]
      rhs_batching_dimensions := #[]
-     lhs_contracting_dimensions := args.input_spatial_dimensions.1 ++ [args.input_feature_dimension]
-     rhs_contracting_dimensions := args.kernel_spatial_dimensions.1 ++ [args.kernel_input_feature_dimension]
-
-     c1 := by simp
-     c2 := by simp
-     c3 := by have := args.c13.1; simp_all
-     c4 := by have := args.c18.1; simp_all
-     c5 := by simp
-     c6 := by simp
-     c7 := by simp
-     c8 := by simp
-     c9 := by intro d; simp at d; have := d.2; aesop
+     lhs_contracting_dimensions := args.input_spatial_dimensions.1 ++ #[args.input_feature_dimension]
+     rhs_contracting_dimensions := args.kernel_spatial_dimensions.1 ++ #[args.kernel_input_feature_dimension]
+
+     c1 := sorry
+     c2 := sorry
+     c3 := sorry
+     c4 := sorry
+     c5 := sorry
+     c6 := sorry
+     c7 := sorry
+     c8 := sorry
+     c9 := sorry
      c10 := by sorry
     c11 := by simp
    c12 := by sorry}
-     (t:= 0) (outDims := sorry) sorry
+     (t:=0) (outDims := sorry) sorry
     dot_product
+
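+-- The grouped cases below follow the StableHLO spec: for `feature_group_count > 1`
+-- the lhs is split along `input_feature_dimension` and the rhs along
+-- `kernel_output_feature_dimension`, the per-group results being concatenated
+-- along `output_feature_dimension`; for `batch_group_count > 1` the lhs is
+-- split along `input_batch_dimension` instead. Note that the `split` and
+-- `concatenate` calls are still schematic: they pass the group count and the
+-- dimension directly instead of the `Args`/`Conditions` records those
+-- functions expect, and the recursive call still needs `args` and `h`
+-- adjusted to the split shapes.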
+open Convolution in
+def convolution {r} {lhsDims rhsDims outDims : Dims r}
+    (lhs : TensorIndex lhsDims → R) (rhs : TensorIndex rhsDims → R)
+    (args : Args lhsDims rhsDims)
+    (h : Conditions args outDims) : TensorIndex outDims → R :=
+  if args.feature_group_count > 1 then
+    let lhsDims' : Dims r := ⟨lhsDims.1.modify args.input_feature_dimension.1 (fun d => d / args.feature_group_count), sorry⟩
+    let lhses : Fin args.feature_group_count → TensorIndex lhsDims' → R := sorry
+      -- split lhs args.feature_group_count args.input_feature_dimension
+    let rhsDims' : Dims r := ⟨rhsDims.1.modify args.kernel_output_feature_dimension.1 (fun d => d / args.feature_group_count), sorry⟩
+    let rhses : Fin args.feature_group_count → TensorIndex rhsDims' → R := sorry
+      -- split rhs args.feature_group_count args.kernel_output_feature_dimension
+    let results := fun i => convolution (lhses i) (rhses i) args h
+    concatenate results args.output_feature_dimension
+  else if args.batch_group_count > 1 then
+    let lhses := split lhs args.batch_group_count args.input_batch_dimension
+    let rhses := split rhs args.batch_group_count args.kernel_output_feature_dimension
+    let results := fun i => convolutionCore (lhses i) (rhses i) args h
+    concatenate results args.output_feature_dimension
+  else
+    convolutionCore lhs rhs args h
diff --git a/SciLean/Modules/ML/XLA/Split.lean b/SciLean/Modules/ML/XLA/Split.lean
new file mode 100644
index 00000000..194f38c1
--- /dev/null
+++ b/SciLean/Modules/ML/XLA/Split.lean
@@ -0,0 +1,49 @@
+import SciLean.Modules.ML.XLA.TensorIndex
+
+namespace SciLean
+
+/-! A `split` function that is not part of the StableHLO spec but appears in
+the definition of `convolution`. -/
+
+namespace Split
+
+structure Args {r} (inDims : Dims r) where
+  split_size : ℕ+
+  split_dimension : Fin r
+
+def Args.outShape {r} {inDims : Dims r} (args : Args inDims) : Dims r :=
+  .ofFn fun d =>
+    if d = args.split_dimension then
+      inDims[d] / args.split_size
+    else
+      inDims[d]
+
+def Args.indexMap {r} {inDims : Dims r} (args : Args inDims) :
+    Fin args.split_size × TensorIndex args.outShape
+    ≃
+    TensorIndex inDims := sorry
+
+structure Conditions {r} {inDims : Dims r} (args : Args inDims) (outDims : Dims r) where
+  c1 : inDims[args.split_dimension] % args.split_size = 0
+  c2 : ∀ d,
+    if d = args.split_dimension then
+      outDims[d] = inDims[d] / args.split_size
+    else
+      outDims[d] = inDims[d]
+
+structure Args' {r} (inDims outDims : Dims r)
+
+end Split
+
+open Split in
+def split {r} {inDims outDims : Dims r} (operand : TensorIndex inDims → R)
+    (args : Args inDims) (h : Conditions args outDims)
+    (houtDims : outDims = args.outShape := by infer_var)
+    (i : Fin args.split_size) : TensorIndex outDims → R :=
+  fun j =>
+    operand (args.indexMap (i, houtDims ▸ j))
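+
+-- Worked example (illustrative; the shapes are hypothetical): with
+-- `inDims = [4, 6]`, `split_size = 2` and `split_dimension = 1`, condition
+-- `c1` holds since `6 % 2 = 0` and `outShape = [4, 3]`. Once `indexMap` is
+-- implemented, slice `i : Fin 2` should read `operand` at an offset of
+-- `i * 3` along dimension 1, i.e. the expected specification is
+--
+--   split operand args h rfl i [j0, j1] = operand [j0, i * 3 + j1]
+--
+-- so that the two slices partition `operand` along `split_dimension`.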