concatenate and split functions
lecopivo committed Oct 23, 2024
1 parent b545980 commit 0ebfd32
Showing 3 changed files with 271 additions and 96 deletions.
102 changes: 102 additions & 0 deletions SciLean/Modules/ML/XLA/Concatenate.lean
@@ -0,0 +1,102 @@
import SciLean.Modules.ML.XLA.TensorIndex

/-!
### concatenate
#### Semantics
Concatenates `inputs` along the `dimension` dimension, in the same order as the
given arguments, and produces a `result` tensor. More formally,
`result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]`, where:
1. `id = d0 + ... + dk-1 + kd`.
2. `d` is equal to `dimension`, and `d0`, ... are the `d`th dimension sizes
of `inputs`.
#### Inputs
| Label | Name | Type | Constraints |
|-------|-------------|------------------------------------------------------------|------------------|
| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1-C6) |
| (I2) | `dimension` | constant of type `si64` | (C2), (C4), (C6) |
#### Outputs
| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
| `result` | tensor or per-tensor quantized tensor | (C5-C6) |
#### Constraints
* (C1) `same(element_type(inputs...))`.
* (C2) `same(shape(inputs...))` except for `dim(inputs..., dimension)`.
* (C3) `0 < size(inputs)`.
* (C4) `0 <= dimension < rank(inputs[0])`.
* (C5) `element_type(result) = element_type(inputs[0])`.
* (C6) `shape(result) = shape(inputs[0])` except for:
* `dim(result, dimension) = dim(inputs[0], dimension) + ...`.
#### Examples
```mlir
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
```
[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/concatenate.mlir)
-/
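
/- Editor's note: a minimal, self-contained sketch of the index rule above
(hypothetical helper on plain `Nat`s, not part of this commit). Specialized to
the concatenation axis: walk the per-input sizes `d0, d1, ...`, subtracting
each from the global index `idx` (the `id` of the formula) until it lands
inside input `k`, which yields the local index `kd`. -/
def concatIndex (sizes : List Nat) (idx : Nat) : Option (Nat × Nat) :=
  match sizes with
  | [] => none
  | s :: rest =>
    if idx < s then
      some (0, idx)
    else
      (concatIndex rest (idx - s)).map fun (k, kd) => (k + 1, kd)

-- With the docstring's example shapes [3,2] and [1,2] along dimension 0:
#eval concatIndex [3, 1] 3  -- some (1, 0): result row 3 is row 0 of %input1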

namespace SciLean


namespace Concatenate

structure Args {r k} (inDims : Fin k → Dims r) where
dimension : Fin r

def Args.outShape {r k} {inDims : Fin k → Dims r} (args : Args inDims) : Dims r :=
.ofFn fun d =>
if d = args.dimension then
∑ i, (inDims i)[d]
else if h : 0 < k then
(inDims ⟨0, by linarith⟩)[d]
else
0
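
/- Editor's note: the same shape rule on plain lists, so it can be evaluated
without the `Dims`/`ArrayN` machinery (hypothetical helper, not part of this
commit): the entry at `dimension` is the sum of all inputs' sizes there, and
every other entry is taken from the first input. -/
def outShapeSketch (shapes : List (List Nat)) (dimension : Nat) : List Nat :=
  match shapes with
  | [] => []
  | s₀ :: _ =>
    (List.range s₀.length).map fun d =>
      if d = dimension then (shapes.map (·.getD d 0)).foldl (· + ·) 0
      else s₀.getD d 0

#eval outShapeSketch [[3, 2], [1, 2]] 0  -- [4, 2], as in the docstring example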

def Args.indexMap {r k} {inDims : Fin k → Dims r} (args : Args inDims) :
((i : Fin k) × TensorIndex (inDims i)) ≃ TensorIndex args.outShape := sorry

structure Constraints {r k} {inDims : Fin k → Dims r} (args : Args inDims) (outDims : Dims r) where
c1 : True
c2 : ∀ d, d ≠ args.dimension → ∀ i j, (inDims i)[d] = (inDims j)[d]
c3 : 0 < k
c4 : (0:ℕ) ≤ args.dimension ∧ args.dimension < r
c5 : True
c6 : ∀ d,
if d = args.dimension then
outDims[d] = ∑ i, (inDims i)[d]
else
∀ i, outDims[d] = (inDims i)[d]


end Concatenate


open Concatenate in
def concatenate {r k} {inDims : Fin k → Dims r} {outDims : Dims r}
(inputs : (i : Fin k) → TensorIndex (inDims i) → R)
(args : Args inDims)
(h : Constraints args outDims)
(houtDims : outDims = args.outShape := by infer_var) :
TensorIndex outDims → R :=
fun i =>
let ⟨i,j⟩ := args.indexMap.symm (houtDims ▸ i)
inputs i j
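
/- Editor's note: `concatenate` is pure index bookkeeping. The result tensor,
viewed as a function of its index, routes each output index back to the owning
input through `indexMap.symm`. A hypothetical rank-1, two-input instance of
the same dispatch pattern (not part of this commit): -/
def concat1D {R : Type} {m n : Nat} (f : Fin m → R) (g : Fin n → R) :
    Fin (m + n) → R := fun i =>
  if h : i.val < m then
    f ⟨i.val, h⟩
  else
    have hi : i.val < m + n := i.isLt
    g ⟨i.val - m, by omega⟩

#eval List.ofFn (concat1D (fun i : Fin 3 => i.val) (fun i : Fin 1 => 10 + i.val))
-- [0, 1, 2, 10]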
216 changes: 120 additions & 96 deletions SciLean/Modules/ML/XLA/Convolution.lean
@@ -232,98 +232,80 @@ For hybrid quantized types, performs `hybrid_dequantize_then_op(
-/

namespace Convolution

structure convolution.ArgData {r} (lhsDims rhsDims : Dims r) where
structure Args {r} (lhsDims rhsDims : Dims r) where
window_strides : ArrayN ℤ (r-2)
padding : ArrayN (ℤ×ℤ) (r-2)
lhs_dilation : ArrayN ℕ+ (r-2)
rhs_dilation : ArrayN ℕ+ (r-2)
window_reversal : ArrayN Bool r
input_batch_dimension : Fin r
input_feature_dimension : Fin r
input_spatial_dimensions : ArrayN (r-2)
input_spatial_dimensions : ArrayN (Fin r) (r-2)
kernel_input_feature_dimension : Fin r
kernel_output_feature_dimension : Fin r
kernel_spatial_dimensions : ArrayN (r-2)
kernel_spatial_dimensions : ArrayN (Fin r) (r-2)
output_batch_dimension : Fin r
output_feature_dimension : Fin r
output_spatial_dimensions : ArrayN (r-2)
output_spatial_dimensions : ArrayN (Fin r) (r-2)
feature_group_count : ℕ+
batch_group_count : ℕ+
precision_config : True

namespace Args

def convolution.ArgData.lhsSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry

def convolution.ArgData.rhsSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry

def convolution.ArgData.outSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : Dims (r - 2) := sorry

def convolution.ArgData.lowPadding {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.1)

def convolution.ArgData.highPadding {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.2)

def convolution.ArgData.lhsShapeMap {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def convolution.ArgData.outShapeMap {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def convolution.ArgData.rhsShapeMap {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def convolution.ArgData.outShapeMap {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) :
ArrayN α r ≃ α × ArrayN α (r - 2) × α := sorry


def convolution.ArgData.outDims {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) : Dims r :=

let output_batch_dim_size := lhsDims[args.input_batch_dimension] / args.batch_group_count
let output_feature_dim_size := rhsDims[args.kernel_output_feature_dimension]

let dilated_input_shape := (args.lhsSpatialShape - 1) * args.lhs_dilation + 1
let padded_input_shape := args.lowPadding + dilated_input_shape + args.highPadding
let dilated_window_shape := (args.rhsSpatialShape - 1) * args.rhs_dilation + 1
let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
let output_spatial_dims := if is_empty_window then 0 else (padded_input_shape - dilated_window_shape) / args.window_strides + 1

args.outShapeMap (output_batch_dim_size, output_spatial_dims, output_feature_dim_size)

-- .ofFn fun result_dim =>
-- if result_dim = args.output_batch_dimension then
-- lhsDims[args.input_batch_dimension] / args.batch_group_count
-- else if result_dim = args.output_feature_dimension then
-- rhsDims[args.kernel_output_feature_dimension]
-- else
-- let spatial_dim : Fin (r-2) := sorry -- result_dim - args.output_feature_dimension - 1
-- let lhs_dim := args.input_spatial_dimensions[spatial_dim]
-- let rhs_dim := args.kernel_spatial_dimensions[spatial_dim]
-- let dilated_input_shape := (args.lhsSpatialShape[spatial_dim] - 1) * args.lhs_dilation[spatial_dim] + 1
-- let padded_input_shape := args.lowPadding[spatial_dim] + dilated_input_shape + args.highPadding[spatial_dim]
-- let dilated_window_shape := (args.rhsspatialShape[spatial_dim] - 1) * args.rhs_dilation[spatial_dim] + 1
-- let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
-- if is_empty_window then 0 else (padded_input_shape - dilated_window_shape) / args.window_strides[spatial_dim] + 1

structure convolution.Conditions {r} {lhsDims rhsDims : Dims r}
(args : convolution.ArgData lhsDims rhsDims) (outDims : Dims r) where
def lhsSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : Dims (r - 2) := sorry

def rhsSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : Dims (r - 2) := sorry

def outSpatialShape {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : Dims (r - 2) := sorry

def lowPadding {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.1)

def highPadding {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : ArrayN ℤ (r-2) := args.padding.map (·.2)

def lhsShapeMap {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def rhsShapeMap {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def outShapeMap {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) :
α × ArrayN α (r - 2) × α ≃ ArrayN α r := sorry

def outDims {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) : Dims r :=

let output_batch_dim_size := lhsDims[args.input_batch_dimension] / args.batch_group_count
let output_feature_dim_size := rhsDims[args.kernel_output_feature_dimension]

let dilated_input_shape := (args.lhsSpatialShape - 1) * args.lhs_dilation + 1
let padded_input_shape := args.lowPadding + dilated_input_shape + args.highPadding
let dilated_window_shape := (args.rhsSpatialShape - 1) * args.rhs_dilation + 1
let is_empty_window := padded_input_shape ≤ 0 || dilated_window_shape > padded_input_shape
let output_spatial_dims := if is_empty_window then 0 else (padded_input_shape - dilated_window_shape) / args.window_strides + 1

args.outShapeMap (output_batch_dim_size, output_spatial_dims, output_feature_dim_size)
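
/- Editor's note: a scalar version of the spatial-size arithmetic above, for a
single spatial dimension (hypothetical helper, not part of this commit). With
input size 5, window 3, stride 1 and no dilation or padding this gives
5 - 3 + 1 = 3. -/
def outSpatialSize (input window stride lhsDil rhsDil lo hi : Int) : Int :=
  let dilatedInput := (input - 1) * lhsDil + 1
  let paddedInput := lo + dilatedInput + hi
  let dilatedWindow := (window - 1) * rhsDil + 1
  if paddedInput ≤ 0 ∨ dilatedWindow > paddedInput then 0
  else (paddedInput - dilatedWindow) / stride + 1

#eval outSpatialSize 5 3 1 1 1 0 0  -- 3
#eval outSpatialSize 5 3 2 1 1 1 1  -- ((1 + 5 + 1) - 3) / 2 + 1 = 3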

end Args

structure Conditions {r} {lhsDims rhsDims : Dims r}
(args : Args lhsDims rhsDims) (outDims : Dims r) where
c1 : lhsDims = rhsDims
c2 : args.window_strides.size = r - 2
c3 : 0 < args.window_strides
@@ -336,18 +336,25 @@ structure convolution.Conditions {r} {lhsDims rhsDims : Dims r}
c10 : lhsDims[args.input_batch_dimension] % args.batch_group_count = 0
c11 : lhsDims[args.input_feature_dimension] % args.feature_group_count = 0
c12 : args.input_spatial_dimensions.size = r - 2
c13 : ∀ d, 0 ≤ args.input_spatial_dimensions[d] ∧ args.input_spatial_dimensions[d] < r
c13 : ∀ d, (0:ℕ) ≤ args.input_spatial_dimensions[d] ∧ args.input_spatial_dimensions[d] < r
c14 : rhsDims[args.kernel_input_feature_dimension] = lhsDims[args.input_feature_dimension] / args.feature_group_count
c15 : rhsDims[args.kernel_output_feature_dimension] % args.batch_group_count = 0


def convolution {r} {lhsDims rhsDims outDims : Dims r}
end Convolution

variable {R} [RealScalar R]

-- case
open Convolution in
def convolutionCore {r} {lhsDims rhsDims outDims : Dims r}
(lhs : TensorIndex lhsDims → R) (rhs : TensorIndex rhsDims → R)
(args : convolution.ArgData lhsDims rhsDims) :
(args : Args lhsDims rhsDims)
(h : Conditions args outDims) :
TensorIndex outDims → R :=

fun i =>
let output_spatial_index : ArrayN ℤ (r-2) := sorry -- get the correct parts of `i`
let (_,output_spatial_index,_) := args.outShapeMap.symm i.1 -- get the correct parts of `i`

let lhsWindowShape :=
args.lhsShapeMap (lhsDims[args.input_batch_dimension],
@@ -359,30 +359,65 @@ def convolution {r} {lhsDims rhsDims outDims : Dims r}
let lhs_base_dilation := args.lhsShapeMap (1,args.lhs_dilation,1)
let lhs_window_dilations := args.lhsShapeMap (1,args.rhs_dilation,1)

let padded_lhs := pad lhs 0 lhs_padding_low lhs_padding_high lhs_base_dilation.toNat
-- there is some issue with elaboration and we have to specify these arguments explicitly
(outDims:= pad.outDims lhsDims lhs_padding_low lhs_padding_high lhs_base_dilation.toNat) (by infer_var)


let lhs_window_start : ArrayN ℤ r := args.lhsShapeMap (0,output_spatial_index,0)
let lhs_window := slice padded_lhs
{ start_indices := lhs_window_start
limit_indices := (lhs_window_start + lhsWindowShape)
strides := lhs_window_dilations
c1 := sorry, c2 := sorry, c3 := sorry, c4 := sorry, c5 := sorry }
-- there is some issue with elaboration and we have to specify these arguments explicitly
(outDims:=slice.outDims lhs_window_start (lhs_window_start + lhsWindowShape) lhs_window_dilations) (by infer_var)

let padded_lhs := pad lhs 0 lhs_padding_low lhs_padding_high lhs_base_dilation.toNat (outDims:= pad.outDims lhsDims lhs_padding_low lhs_padding_high lhs_base_dilation.toNat) (by infer_var)
let lhs_window_start : ArrayN ℤ r := args.lhsShapeMap (0, output_spatial_index, 0)
let lhs_window := slice padded_lhs lhs_window_start (lhs_window_start + lhsWindowShape) lhs_window_dilations
(outDims:=slice.outDims lhs_window_start (lhs_window_start + lhsWindowShape) lhs_window_dilations) (by infer_var)
let dot_product : R :=
dot_general lhs_window rhs
{lhs_batching_dimensions := #[]
rhs_batching_dimensions := #[]
lhs_contracting_dimensions := args.input_spatial_dimensions.1 ++ [args.input_feature_dimension]
rhs_contracting_dimensions := args.kernel_spatial_dimensions.1 ++ [args.kernel_input_feature_dimension]

c1 := by simp
c2 := by simp
c3 := by have := args.c13.1; simp_all
c4 := by have := args.c18.1; simp_all
c5 := by simp
c6 := by simp
c7 := by simp
c8 := by simp
c9 := by intro d; simp at d; have := d.2; aesop
lhs_contracting_dimensions := args.input_spatial_dimensions.1 ++ #[args.input_feature_dimension]
rhs_contracting_dimensions := args.kernel_spatial_dimensions.1 ++ #[args.kernel_input_feature_dimension]

c1 := sorry
c2 := sorry
c3 := sorry
c4 := sorry
c5 := sorry
c6 := sorry
c7 := sorry
c8 := sorry
c9 := sorry
c10 := by sorry
c11 := by simp
c12 := by sorry}
(t:= 0) (outDims := sorry) sorry
(t:=0) (outDims := sorry) sorry

dot_product
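
/- Editor's note: the body above is the stablehlo recipe: dilate and pad the
input, slice out the window at the current output position, then contract the
window against the kernel with `dot_general`. A hypothetical 1-D, stride-1,
undilated, unpadded analogue of the whole pipeline (not part of this commit): -/
def conv1D (x w : List Int) : List Int :=
  (List.range (x.length + 1 - w.length)).map fun p =>
    ((x.drop p).zip w).foldl (fun acc (a, b) => acc + a * b) 0

#eval conv1D [1, 2, 3, 4, 5] [1, 0, -1]  -- [-2, -2, -2]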


#check Array.modify

open Convolution in
def convolution {r} {lhsDims rhsDims outDims : Dims r}
(lhs : TensorIndex lhsDims → R) (rhs : TensorIndex rhsDims → R)
(args : Args lhsDims rhsDims)
(h : Conditions args outDims) : TensorIndex outDims → R :=

if args.feature_group_count > 1 then
let lhsDims' : Dims r := ⟨lhsDims.1.modify args.input_feature_dimension.1 (fun d => d / args.feature_group_count), sorry⟩
let lhses : Fin args.feature_group_count → TensorIndex lhsDims' → R := sorry
--split lhs args.feature_group_count args.input_feature_dimension
let rhsDims' : Dims r := ⟨rhsDims.1.modify args.kernel_output_feature_dimension.1 (fun d => d / args.feature_group_count), sorry⟩
let rhses : Fin args.feature_group_count → TensorIndex rhsDims' → R := sorry
--split rhs args.feature_group_count args.kernel_output_feature_dimension
let results := fun i => convolution (lhses i) (rhses i) args h
concatenate results args.output_feature_dimension
else if args.batch_group_count > 1 then
let lhses := split lhs args.batch_group_count args.input_batch_dimension
let rhses := split rhs args.batch_group_count args.kernel_output_feature_dimension
let results := lhses.zipWith rhses (fun lhs rhs => convolutionCore lhs rhs args h)
concatenate results args.output_feature_dimension
else
convolutionCore lhs rhs args h
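
/- Editor's note: `split` comes from the third changed file, which did not
render on this page. A hypothetical rank-1 sketch of the intended inverse of
`concatenate` (not part of this commit): cutting an axis of size `g * n` into
`g` pieces of size `n`, so piece `k` reads the global index `k * n + i`. -/
def split1D {α : Type} (g n : Nat) (f : Fin (g * n) → α) :
    Fin g → Fin n → α := fun k i =>
  f ⟨k.val * n + i.val,
     calc k.val * n + i.val < k.val * n + n := Nat.add_lt_add_left i.isLt _
       _ = (k.val + 1) * n := (Nat.succ_mul _ _).symm
       _ ≤ g * n := mul_le_mul_right' (Nat.succ_le_of_lt k.isLt) n⟩

-- Splitting then reading piece 1 at offset 0 recovers element 1 * n + 0:
#eval split1D 2 2 (fun i : Fin (2 * 2) => i.val) 1 0  -- 2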