[Pallas:TPU] Fix lowering of convert_element_type(float32) -> bool.
The original implementation did not handle 0 < |x| < 1 correctly. It used to lower as convert_element_type(x, int32) ==> 0 ==> convert_element_type(0, bool) ==> false, which differs from XLA semantics, where convert_element_type(x, bool) ==> true.

The Hypothesis library seems to draw values such as 0.5, which fall in exactly this range.
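
For illustration, the mismatch is easy to see with plain jnp outside Pallas (a quick editorial sketch of the semantics described above, not part of the change itself):

import jax.numpy as jnp

x = jnp.array(0.5, dtype=jnp.float32)
print(x.astype(jnp.bool_))                    # True: a nonzero float converts to true.
print(x.astype(jnp.int32).astype(jnp.bool_))  # False: the old int32 round trip maps 0.5 to 0.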

While I'm here, remove some stale skip conditions from the tests; the underlying issues were fixed by recent Pallas/Mosaic changes.

PiperOrigin-RevId: 680779747
Google-ML-Automation committed Oct 1, 2024
1 parent ce21a12 commit 04b0aba
Showing 2 changed files with 23 additions and 21 deletions.
jax/_src/pallas/mosaic/lowering.py (36 changes: 23 additions & 13 deletions)
@@ -1617,18 +1617,20 @@ def _convert_helper(x, *, to_dtype):
       x = x.astype(jnp.float32)
     return x.astype(to_dtype)
   if jnp.issubdtype(from_dtype, jnp.floating):
-    if jnp.issubdtype(to_dtype, jnp.signedinteger):
+    if jnp.issubdtype(to_dtype, np.dtype("bool")):
+      # Cast to float32 rather than int32 because 0 < |x| < 1 rounds to 0,
+      # leading to false in bool. However, convert_element_type(x, bool)
+      # returns true. It's handled correctly when x is float32.
+      x = x.astype(jnp.float32)
+    elif jnp.issubdtype(to_dtype, jnp.signedinteger):
       if from_dtype.itemsize < 4:
         x = x.astype(jnp.float32)
       if to_dtype.itemsize < 4:
         # Need to clip values to match XLA
         minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max
         x = jnp.clip(x, minval, maxval)
         return x.astype(jnp.int32).astype(to_dtype)
-      return x.astype(to_dtype)
-    elif jnp.issubdtype(to_dtype, np.dtype("bool")):
-      x = x.astype(jnp.int32)
-      return x.astype(jnp.float32)
+    return x.astype(to_dtype)
   raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")

 def _convert_element_type_lowering_rule(
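
As a side note on the integer branch kept above, the clip-then-cast sequence mirrors this jnp-level behaviour for narrow integer targets (a standalone sketch with a made-up helper name, not the lowering code itself):

import jax.numpy as jnp

def f32_to_i8(x):
  # Clip to the int8 range first so out-of-range floats saturate, which is the
  # "clip values to match XLA" behaviour the in-line comment refers to.
  minval, maxval = jnp.iinfo(jnp.int8).min, jnp.iinfo(jnp.int8).max
  x = jnp.clip(x, minval, maxval)
  return x.astype(jnp.int32).astype(jnp.int8)

print(f32_to_i8(jnp.float32(1000.0)))  # 127 rather than a wrapped-around value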
@@ -1675,24 +1677,32 @@ def _convert_element_type_lowering_rule(
   ):
     return arith.extui(out_type, x)
   elif (
-      jnp.issubdtype(old_dtype, jnp.integer)
+      (
+          (is_float := jnp.issubdtype(old_dtype, jnp.floating))
+          or jnp.issubdtype(old_dtype, jnp.integer)
+      )
       and new_dtype == jnp.bool_
       and old_dtype.itemsize == 4
   ):
-    pred = _cmpi_lowering_types[lax.ne_p]
-    predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
+    # Lower float32 or (u)int32 -> bool to cmp neq %in, 0
     const_type = _dtype_to_ir_type(old_dtype)
-    const_zero = ir.IntegerAttr.get(const_type, 0)
+    if is_float:
+      pred = _cmpf_lowering_types[lax.ne_p]
+      const_zero = ir.FloatAttr.get(const_type, 0)
+      op = arith.CmpFOp
+    else:
+      pred = _cmpi_lowering_types[lax.ne_p]
+      const_zero = ir.IntegerAttr.get(const_type, 0)
+      op = arith.CmpIOp
+    predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
     if in_aval.shape:
       in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False)
       vector_zeros = arith.ConstantOp(
           in_type,
           ir.DenseElementsAttr.get_splat(in_type, const_zero),
       )
-      return arith.CmpIOp(predicate, x, vector_zeros).result
-    return arith.CmpIOp(
-        predicate, x, arith.ConstantOp(const_type, const_zero)
-    ).result
+      return op(predicate, x, vector_zeros).result
+    return op(predicate, x, arith.ConstantOp(const_type, const_zero)).result
   return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
                    multiple_results=False)(ctx, x)
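
Taken together, the two hunks mean a float32 -> bool cast inside a Pallas TPU kernel now lowers to a compare-not-equal against 0.0 (vector or scalar) instead of routing through int32. A minimal usage sketch, assuming a TPU backend; the kernel name, block shape, and fill value below are illustrative and not taken from the test suite:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def cast_kernel(x_ref, o_ref):
  # The astype below emits convert_element_type(float32 -> bool), which this
  # commit lowers via a cmpf-ne-with-zero rather than an int32 round trip.
  o_ref[...] = x_ref[...].astype(jnp.bool_)

x = jnp.full((8, 128), 0.5, dtype=jnp.float32)
out = pl.pallas_call(
    cast_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.bool_),
)(x)
# With the fix, out agrees with XLA: everywhere True, same as x.astype(jnp.bool_).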

tests/pallas/ops_test.py (8 changes: 0 additions & 8 deletions)
@@ -561,14 +561,6 @@ def test_cast(self, from_dtype, to_dtype, data):
       self.skipTest("Not supported: bad canonicalization")
     if from_dtype == "bool" and to_dtype in {"int16", "int8"}:
       self.skipTest("Not supported: cannot extend to sub-32 bit types")
-    if from_dtype in {"bfloat16", "float32"} and to_dtype == "bool":
-      self.skipTest("Not supported: unsupported relayout")
-    if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}:
-      self.skipTest("Not supported: unsupported relayout")
-    if from_dtype in {"int16", "int8"} and to_dtype == "bool":
-      self.skipTest("Not supported: cannot truncate from sub-32 bit types")
-    if from_dtype in {"int16", "int8"} and to_dtype == "bool":
-      self.skipTest("Not supported: cannot truncate from sub-32 bit types")
     if jtu.test_device_matches(["gpu"]):
       if (from_dtype in {"bfloat16", "float32"} and
           to_dtype in {"int8", "int16", "int32"}):
