Skip to content

Commit

Permalink
Merge pull request #184 from pperle/fix-float16-conversion
Browse files Browse the repository at this point in the history
Fix for Float16 Conversion in ByteConversionUtils
  • Loading branch information
PaulTR authored Jan 18, 2024
2 parents 0535ea2 + 0023cb0 commit 5255c97
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 33 deletions.
116 changes: 83 additions & 33 deletions lib/src/util/byte_conversion_utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import 'dart:typed_data';

import 'package:tflite_flutter/tflite_flutter.dart';

class ByteConvertionError extends ArgumentError {
ByteConvertionError({
class ByteConversionError extends ArgumentError {
ByteConversionError({
required this.input,
required this.tensorType,
}) : super(
Expand Down Expand Up @@ -52,19 +52,13 @@ class ByteConversionUtils {
static Uint8List _convertElementToBytes(Object o, TensorType tensorType) {
// Float32
if (tensorType.value == TfLiteType.kTfLiteFloat32) {
if (o is double) {
var buffer = Uint8List(4).buffer;
var bdata = ByteData.view(buffer);
bdata.setFloat32(0, o, Endian.little);
return buffer.asUint8List();
}
if (o is int) {
if (o is num) {
var buffer = Uint8List(4).buffer;
var bdata = ByteData.view(buffer);
bdata.setFloat32(0, o.toDouble(), Endian.little);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand All @@ -78,7 +72,7 @@ class ByteConversionUtils {
bdata.setUint8(0, o);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand All @@ -92,7 +86,7 @@ class ByteConversionUtils {
bdata.setInt32(0, o, Endian.little);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand All @@ -106,7 +100,7 @@ class ByteConversionUtils {
bdata.setInt64(0, o, Endian.big);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand All @@ -120,27 +114,18 @@ class ByteConversionUtils {
bdata.setInt16(0, o, Endian.little);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
}

// Float16
if (tensorType.value == TfLiteType.kTfLiteFloat16) {
if (o is double) {
var buffer = Uint8List(4).buffer;
var bdata = ByteData.view(buffer);
bdata.setFloat32(0, o, Endian.little);
return buffer.asUint8List().sublist(0, 2);
if (o is num) {
return ByteConversionUtils.floatToFloat16Bytes(o.toDouble());
}
if (o is int) {
var buffer = Uint8List(4).buffer;
var bdata = ByteData.view(buffer);
bdata.setFloat32(0, o.toDouble(), Endian.little);
return buffer.asUint8List().sublist(0, 2);
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand All @@ -154,7 +139,7 @@ class ByteConversionUtils {
bdata.setInt8(0, o);
return buffer.asUint8List();
}
throw ByteConvertionError(
throw ByteConversionError(
input: o,
tensorType: tensorType,
);
Expand Down Expand Up @@ -185,13 +170,10 @@ class ByteConversionUtils {
}
return list.reshape<int>(shape);
} else if (tensorType.value == TfLiteType.kTfLiteFloat16) {
Uint8List list32 = Uint8List(bytes.length * 2);
for (var i = 0; i < bytes.length; i += 2) {
list32[i] = bytes[i];
list32[i + 1] = bytes[i + 1];
}
for (var i = 0; i < list32.length; i += 4) {
list.add(ByteData.view(list32.buffer).getFloat32(i, Endian.little));
int float16 = ByteData.view(bytes.buffer).getUint16(i, Endian.little);
double float32 = _float16ToFloat32(float16);
list.add(float32);
}
return list.reshape<double>(shape);
} else if (tensorType.value == TfLiteType.kTfLiteInt8) {
Expand All @@ -212,4 +194,72 @@ class ByteConversionUtils {
}
throw UnsupportedError("$tensorType is not Supported.");
}

static Uint8List floatToFloat16Bytes(double value) {
int float16 = _float32ToFloat16(value);
final ByteData byteDataBuffer = ByteData(2)
..setUint16(0, float16, Endian.little);
return Uint8List.fromList(byteDataBuffer.buffer.asUint8List());
}

static int _float32ToFloat16(double value) {
final Float32List float32Buffer = Float32List(1);
final Uint32List int32Buffer = float32Buffer.buffer.asUint32List();

float32Buffer[0] = value;
int f = int32Buffer[0];
int sign = (f >> 16) & 0x8000;
int exponent = (f >> 23) & 0xFF;
int mantissa = f & 0x007FFFFF;

if (exponent == 0) return sign;
if (exponent == 255) return sign | 0x7C00;

exponent = exponent - 127 + 15;
if (exponent >= 31) return sign | 0x7C00;
if (exponent <= 0) return sign;

// Implement rounding
int roundMantissa = (mantissa >> 13) + ((mantissa >> 12) & 1);

return sign | (exponent << 10) | roundMantissa;
}

static double bytesToFloat32(Uint8List bytes) {
final ByteData byteDataBuffer = ByteData(2);
int float16 = byteDataBuffer.buffer
.asUint8List()
.buffer
.asByteData()
.getUint16(0, Endian.little);
return _float16ToFloat32(float16);
}

static double _float16ToFloat32(int value) {
final Float32List float32Buffer = Float32List(1);
final Uint32List int32Buffer = float32Buffer.buffer.asUint32List();

int sign = (value & 0x8000) << 16;
int exponent = (value & 0x7C00) >> 10;
int mantissa = (value & 0x03FF) << 13;

if (exponent == 0) {
if (mantissa == 0) return sign == 0 ? 0.0 : -0.0;
while ((mantissa & 0x00800000) == 0) {
mantissa <<= 1;
exponent -= 1;
}
exponent += 1;
} else if (exponent == 31) {
if (mantissa == 0) {
return sign == 0 ? double.infinity : double.negativeInfinity;
}
return double.nan;
}

exponent = exponent - 15 + 127;
int32Buffer[0] = sign | (exponent << 23) | mantissa;

return float32Buffer[0];
}
}
118 changes: 118 additions & 0 deletions test/util/byte_conversion_utils_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import 'package:flutter_test/flutter_test.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

void main() {
group('convertObjectToBytes and convertBytesToObject', () {
test('TensorType.float32', () async {
var bytes =
ByteConversionUtils.convertObjectToBytes(1.1, TensorType.float32);
expect(bytes, [205, 204, 140, 63]);
var object = ByteConversionUtils.convertBytesToObject(
bytes, TensorType.float32, [1]) as List;
expect(object[0], closeTo(1.1, 0.0001));
});

test('TensorType.float16', () async {
var bytes =
ByteConversionUtils.convertObjectToBytes(1.1, TensorType.float16);
expect(bytes, [102, 60]);
var object = ByteConversionUtils.convertBytesToObject(
bytes, TensorType.float16, [1]) as List;
expect(object[0], closeTo(1.1, 0.001));

/*
```python
import tensorflow as tf
for value in [1.2, 1.3, 1.4, 1.5]:
value_tf = tf.constant(value, dtype=tf.float16)
byte_data_tf = tf.io.serialize_tensor(value_tf)
last_two_bytes_tf = byte_data_tf.numpy()[-2:] # Get the last two bytes
print([x for x in last_two_bytes_tf], value_tf.numpy().item()) # [0, 62] 1.5
```
[205, 60] 1.2001953125
[51, 61] 1.2998046875
[154, 61] 1.400390625
[0, 62] 1.5
*/

List<double> values = [1.2, 1.3, 1.4, 1.5];
List<List<int>> bytesList = [
[205, 60],
[51, 61],
[154, 61],
[0, 62]
];

for (int i = 0; i < values.length; i++) {
var bytes = ByteConversionUtils.convertObjectToBytes(
values[i], TensorType.float16);
expect(bytes, bytesList[i]);
var object = ByteConversionUtils.convertBytesToObject(
bytes, TensorType.float16, [1]) as List;
expect(object[0], closeTo(values[i], 0.001));
}
});

test('TensorType.int64', () async {
var bytes = ByteConversionUtils.convertObjectToBytes(1, TensorType.int64);
expect(bytes, [0, 0, 0, 0, 0, 0, 0, 1]);
var object =
ByteConversionUtils.convertBytesToObject(bytes, TensorType.int64, [1])
as List;
expect(object[0], 1);
});

test('TensorType.int32', () async {
var bytes = ByteConversionUtils.convertObjectToBytes(1, TensorType.int32);
expect(bytes, [1, 0, 0, 0]);
var object =
ByteConversionUtils.convertBytesToObject(bytes, TensorType.int32, [1])
as List;
expect(object[0], 1);
});

test('TensorType.int16', () async {
var bytes = ByteConversionUtils.convertObjectToBytes(1, TensorType.int16);
expect(bytes, [1, 0]);
var object =
ByteConversionUtils.convertBytesToObject(bytes, TensorType.int16, [1])
as List;
expect(object[0], 1);
});

test('TensorType.int8', () async {
var bytes = ByteConversionUtils.convertObjectToBytes(1, TensorType.int8);
expect(bytes, [1]);
var object =
ByteConversionUtils.convertBytesToObject(bytes, TensorType.int8, [1])
as List;
expect(object[0], 1);
});

test('TensorType.uint8', () async {
var bytes = ByteConversionUtils.convertObjectToBytes(1, TensorType.uint8);
expect(bytes, [1]);
var object =
ByteConversionUtils.convertBytesToObject(bytes, TensorType.uint8, [1])
as List;
expect(object[0], 1);
});
});

group('errors', () {
test('float to int8', () async {
expect(
() => ByteConversionUtils.convertObjectToBytes(1.1, TensorType.int8),
throwsA(isA<ByteConversionError>()));
});

test('float to None', () async {
expect(
() =>
ByteConversionUtils.convertObjectToBytes(1.1, TensorType.noType),
throwsA(isA<ArgumentError>()));
});
});
}

0 comments on commit 5255c97

Please sign in to comment.