Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK authored Nov 13, 2024
1 parent 830b978 commit 37487a8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 35 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "95c2d798148f12565dd4c9ddc753d196e47f230f"
LLVM_COMMIT = "01d233ff403823389f8480897e41aea84ecbb3d3"

LLVM_SHA256 = "f11e5bbf17d50ff31addc9e1737d64e64a144fce928166de5878c72a1efcf9b4"
LLVM_SHA256 = "283a1d9c251d5028ae78f7a659816588fedaa6a8ba5733bee7249724fb3ed2bc"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
95c2d798148f12565dd4c9ddc753d196e47f230f
01d233ff403823389f8480897e41aea84ecbb3d3
44 changes: 12 additions & 32 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,42 +155,22 @@ Element Tensor::get(const Index &index) const {
// integer variants.
if (isSupportedIntegerType(elementType)) {
IntegerType intTy = cast<IntegerType>(elementType);

if (elementType.isSignlessInteger(2) || elementType.isSignlessInteger(4) ||
elementType.isSignlessInteger(8)) {
auto elementData = reinterpret_cast<const int8_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isSignlessInteger(16)) {
auto elementData = reinterpret_cast<const int16_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isSignlessInteger(32)) {
auto elementData = reinterpret_cast<const int32_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isSignlessInteger(64)) {
auto elementData = reinterpret_cast<const int64_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isUnsignedInteger(2) ||
elementType.isUnsignedInteger(4) ||
elementType.isUnsignedInteger(8)) {
const unsigned int bitwidth = intTy.getWidth();
if (bitwidth == 2 || bitwidth == 4 || bitwidth == 8) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isUnsignedInteger(16)) {
// Set implicitTrunc to ignore garbage bits on 2-bit and 4-bit types.
const bool implicitTrunc = bitwidth == 2 || bitwidth == 4;
return Element(elementType, APInt(bitwidth, *elementData,
/*isSigned=*/false, implicitTrunc));
} else if (bitwidth == 16) {
auto elementData = reinterpret_cast<const uint16_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isUnsignedInteger(32)) {
return Element(elementType, APInt(bitwidth, *elementData));
} else if (bitwidth == 32) {
auto elementData = reinterpret_cast<const uint32_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isUnsignedInteger(64)) {
return Element(elementType, APInt(bitwidth, *elementData));
} else if (bitwidth == 64) {
auto elementData = reinterpret_cast<const uint64_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
return Element(elementType, APInt(bitwidth, *elementData));
}
}

Expand Down

0 comments on commit 37487a8

Please sign in to comment.