Skip to content

Commit

Permalink
Merge pull request #902 from xmos/multi-softmax
Browse files Browse the repository at this point in the history
Add batched softmax operator
  • Loading branch information
panickal-xmos authored Jul 11, 2024
2 parents 3c0a431 + 1e0cd80 commit 175ca98
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 4 deletions.
40 changes: 40 additions & 0 deletions integration_tests/models/8x8/test_softmax/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

import numpy as np
import tensorflow as tf

# Generates a tiny quantized softmax test model for the integration suite.
# The model has a non-unit batch dimension so it exercises the batched
# softmax lowering path.

BATCH_SIZE = 100       # examples per batch in the generated model
FEATURE_SHAPE = (2,)   # each example is a length-2 vector

# Build a trivial Keras model: a single Softmax over the feature axis.
inputs = tf.keras.Input(shape=FEATURE_SHAPE, batch_size=BATCH_SIZE)
print(inputs.shape)

outputs = tf.keras.layers.Softmax()(inputs)
print(outputs.shape)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Convert to TensorFlow Lite with full int8 quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(model)


def representative_dataset_gen():
    # 100 random calibration batches drawn uniformly from [-1, 1) so the
    # converter can calibrate the quantization ranges.
    for _ in range(100):
        sample = np.random.uniform(
            low=-1., high=1., size=(BATCH_SIZE,) + FEATURE_SHAPE
        ).astype(np.float32)
        yield [sample]


converter.representative_dataset = representative_dataset_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()

# Write the flatbuffer next to this script.
model_name = 'test_softmax_10.tflite'
with open(model_name, 'wb') as f:
    f.write(tflite_model)

print(f"TFLite model saved as {model_name}")
1 change: 1 addition & 0 deletions integration_tests/models/8x8/test_softmax/params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Maximum allowed absolute error vs. the reference output (0.0 = exact match).
MAX_ABS_ERROR: 0.0
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion third_party/lib_nn
10 changes: 10 additions & 0 deletions xformer/IR/XCoreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ def XC_SoftmaxOp : XC_Op<"softmax", [Pure]> {
let results = (outs TensorOf<[QI8]> : $output);
}

// Softmax applied to a rank-2 input whose leading (batch) dimension is
// not 1 (see the multi-batch rewrite pattern in XCPatterns.td). Takes the
// quantized int8 input plus a float32 lookup table operand ($lut) that is
// materialized as a constant at rewrite time.
def XC_BatchedSoftmaxOp : XC_Op<"batched_softmax", [Pure]> {
  let summary = "Batched softmax op";

  let description = [{Batched softmax op.}];

  let arguments = (ins TensorOf<[QI8]> : $input, TensorOf<[F32]> : $lut);

  let results = (outs TensorOf<[QI8]> : $output);
}

def XC_Conv2DV2Op : XC_Op<"conv2d_v2", [Pure]> {
let summary = "Conv2D V2 op";

Expand Down
2 changes: 2 additions & 0 deletions xformer/Transforms/TranslateToCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ std::vector<uint8_t> Beta_TransposeConvF32Op::buildCustomOptions() {
// These ops need no configuration beyond their operands, so their
// custom-options payload is an empty byte vector.
std::vector<uint8_t> Beta_FcF32Op::buildCustomOptions() { return {}; }
std::vector<uint8_t> LookupOp::buildCustomOptions() { return {}; }
std::vector<uint8_t> SoftmaxOp::buildCustomOptions() { return {}; }
std::vector<uint8_t> BatchedSoftmaxOp::buildCustomOptions() { return {}; }

std::vector<uint8_t> AddOp::buildCustomOptions() {
flexbuffers::Builder fbb;
Expand Down Expand Up @@ -246,6 +247,7 @@ void TranslateToCustomOp::runOnOperation() {
patterns.insert<RewriteToCustomOp<LoadFlashOp>>(ctx);
patterns.insert<RewriteToCustomOp<LookupOp>>(ctx);
patterns.insert<RewriteToCustomOp<SoftmaxOp>>(ctx);
patterns.insert<RewriteToCustomOp<BatchedSoftmaxOp>>(ctx);
patterns.insert<RewriteToCustomOp<MulOp>>(ctx);
patterns.insert<RewriteToCustomOp<Pad3To4Op>>(ctx);
patterns.insert<RewriteToCustomOp<SliceOp>>(ctx);
Expand Down
12 changes: 10 additions & 2 deletions xformer/Transforms/XCPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,22 @@ def getExpLookupF32
// Restricts the rewrite to rank-2 inputs (one leading dim plus the
// softmax axis).
def isSingleSegment
    : Constraint<CPred<"$0.getType().cast<ShapedType>().getRank() == 2">>;

// Leading (batch) dimension is exactly 1.
def isSingleBatch : Constraint<CPred<"$0.getType().cast<ShapedType>().getDimSize(0) == 1">>;
// Leading (batch) dimension is not 1. NOTE(review): a dynamic batch dim
// (getDimSize returns ShapedType::kDynamic) also satisfies != 1 and would
// take the batched path — confirm that is intended.
def isMultiBatch : Constraint<CPred<"$0.getType().cast<ShapedType>().getDimSize(0) != 1">>;

// Only rewrite softmax whose beta attribute is exactly 1.0.
def betaIsOne : Constraint<CPred<"$0.getValue().convertToFloat() == 1.0">>;

// Softmax -> XC_SoftmaxOp if single batch else XC_BatchedSoftmaxOp.
// Single-batch case: quantized-int8 TFL softmax with beta == 1 and a
// rank-2 input whose batch dim is 1 lowers to XC_SoftmaxOp; the lookup
// table built by getExpLookupF32 from $output is folded into a constant.
def:
Pat<(TFL_SoftmaxOp
     : $output TensorOf<[QI8]>:$input, $beta),
    (XC_SoftmaxOp $input, (Arith_ConstantOp (getExpLookupF32
     $output))), [(betaIsOne $beta), (isSingleSegment $input), (isSingleBatch $input)]>;
// Multi-batch case: identical preconditions except the batch dim is not 1;
// lowers to XC_BatchedSoftmaxOp with the same constant lookup table.
def:
Pat<(TFL_SoftmaxOp
     : $output TensorOf<[QI8]>:$input, $beta),
    (XC_BatchedSoftmaxOp $input, (Arith_ConstantOp (getExpLookupF32
     $output))), [(betaIsOne $beta), (isSingleSegment $input), (isMultiBatch $input)]>;

// Beta float activation lookup
def getActivationType
Expand Down
1 change: 1 addition & 0 deletions xformer/lib_tflite_micro.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ filegroup(
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_load_from_flash.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_lookup.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_softmax.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_batched_softmax.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_add.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_pad.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_concat.cc",
Expand Down

0 comments on commit 175ca98

Please sign in to comment.