Error in compilation on GPU #15744

Open
jakevdp opened this issue Aug 5, 2024 · 0 comments

Comments

jakevdp (Contributor) commented Aug 5, 2024

First reported at jax-ml/jax#22865

This appears to be an issue in one of the fusion passes applied to the vmapped program.

Here's a more stripped-down repro, with the traceback from a GPU runtime (jax v0.4.31):

import jax
import jax.numpy as jnp

@jax.vmap
def fn(x):
    # x has shape (3,) per batch element.
    R1 = jnp.array([[x[0], 0, 0],
                    [0, x[0], 0],
                    [0, 0, x[0]]])
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    H = jnp.eye(4)
    H = H.at[:3, :3].set(R2.T)  # lowers to a stablehlo.scatter under vmap
    pos = H @ jnp.concatenate([x, jnp.array([1.0])])
    return pos, R1

x_v = jnp.zeros((5, 3))
jax.jit(fn).lower(x_v).compile()  # raises XlaRuntimeError on GPU
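
For contrast, the same lowering appears to compile when targeting CPU (a sketch continuing the repro above; that the failure is GPU-specific is an assumption based on the title and the linked report):

with jax.default_device(jax.devices("cpu")[0]):
    jax.jit(fn).lower(x_v).compile()  # expected to succeed if the bug is GPU-only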
Traceback
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-1-11d7cd79e4c8> in <cell line: 18>()
     16 
     17 x_v = jnp.zeros((5, 3))
---> 18 jax.jit(fn).lower(x_v).compile()

/usr/local/lib/python3.10/dist-packages/jax/_src/stages.py in compile(self, compiler_options)
    673     kw: dict[str, Any] = {"compiler_options": compiler_options}
    674     return Compiled(
--> 675         self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
    676         self.args_info,
    677         self.out_tree,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in compile(self, compiler_options)
   2293   def compile(self, compiler_options=None) -> MeshExecutable:
   2294     if self._executable is None or compiler_options is not None:
-> 2295       executable = UnloadedMeshExecutable.from_hlo(
   2296           self._name, self._hlo, **self.compile_args,
   2297           compiler_options=compiler_options)

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in from_hlo(***failed resolving arguments***)
   2805           break
   2806 
-> 2807     xla_executable = _cached_compilation(
   2808         hlo, name, mesh, spmd_lowering,
   2809         tuple_args, auto_spmd_lowering, allow_prop_to_inputs,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values, pgle_profiler)
   2619       "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
   2620       fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2621     xla_executable = compiler.compile_or_get_cached(
   2622         backend, computation, dev, compile_options, host_callbacks,
   2623         pgle_profiler)

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks, pgle_profiler)
    397   else:
    398     log_persistent_cache_miss(module_name, cache_key)
--> 399     return _compile_and_write_cache(
    400         backend,
    401         computation,

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in _compile_and_write_cache(backend, computation, compile_options, host_callbacks, module_name, cache_key)
    625 ) -> xc.LoadedExecutable:
    626   start_time = time.monotonic()
--> 627   executable = backend_compile(
    628       backend, computation, compile_options, host_callbacks
    629   )

/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    334   def wrapper(*args, **kwargs):
    335     with TraceAnnotation(name, **decorator_kwargs):
--> 336       return func(*args, **kwargs)
    337     return wrapper
    338   return wrapper

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in backend_compile(backend, module, options, host_callbacks)
    265   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    266   # to take in `host_callbacks`
--> 267   return backend.compile(built_c, compile_options=options)
    268 
    269 def compile_or_get_cached(

XlaRuntimeError: INVALID_ARGUMENT: Binary op with incompatible shapes: f32[4,5,4] and f32[5,4,4].
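
To bisect which pass introduces the mismatch, the per-pass HLO can be dumped with XLA's standard debug flags (a sketch; /tmp/xla_dump is an arbitrary path, and XLA_FLAGS must be set before the backend initializes):

import os
# Dump the HLO before and after every pass matching the regex, so the pass
# that turns f32[5,4,4] into f32[4,5,4] can be pinpointed.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*"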

The HLO that is sent to XLA looks like this:

print(jax.jit(fn).lower(x_v).as_text())
output
module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<5x3xf32> {mhlo.layout_mode = "default"}) -> (tensor<5x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<5x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
    %0 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %1 = stablehlo.reshape %0 : (tensor<5x1xf32>) -> tensor<5xf32>
    %2 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<5x1xf32>) -> tensor<5xf32>
    %4 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %5 = stablehlo.reshape %4 : (tensor<5x1xf32>) -> tensor<5xf32>
    %6 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %8 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %9 = stablehlo.broadcast_in_dim %7, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %10 = stablehlo.broadcast_in_dim %8, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %11 = stablehlo.concatenate %6, %9, %10, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %12 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %13 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %15 = stablehlo.broadcast_in_dim %12, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %16 = stablehlo.broadcast_in_dim %14, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %17 = stablehlo.concatenate %15, %13, %16, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %18 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %19 = stablehlo.broadcast_in_dim %cst_5, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %20 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %21 = stablehlo.broadcast_in_dim %18, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %22 = stablehlo.broadcast_in_dim %19, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %23 = stablehlo.concatenate %21, %22, %20, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %24 = stablehlo.broadcast_in_dim %11, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %25 = stablehlo.broadcast_in_dim %17, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %26 = stablehlo.broadcast_in_dim %23, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %27 = stablehlo.concatenate %24, %25, %26, dim = 1 : (tensor<5x1x3xf32>, tensor<5x1x3xf32>, tensor<5x1x3xf32>) -> tensor<5x3x3xf32>
    %28 = stablehlo.slice %arg0 [0:5, 0:1] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %29 = stablehlo.reshape %28 : (tensor<5x1xf32>) -> tensor<5xf32>
    %30 = stablehlo.slice %arg0 [0:5, 1:2] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %31 = stablehlo.reshape %30 : (tensor<5x1xf32>) -> tensor<5xf32>
    %32 = stablehlo.slice %arg0 [0:5, 2:3] : (tensor<5x3xf32>) -> tensor<5x1xf32>
    %33 = stablehlo.reshape %32 : (tensor<5x1xf32>) -> tensor<5xf32>
    %34 = stablehlo.broadcast_in_dim %29, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %35 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %36 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %37 = stablehlo.broadcast_in_dim %35, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %38 = stablehlo.broadcast_in_dim %36, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %39 = stablehlo.concatenate %34, %37, %38, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %40 = stablehlo.broadcast_in_dim %cst_8, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %41 = stablehlo.broadcast_in_dim %31, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %42 = stablehlo.broadcast_in_dim %cst_9, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %43 = stablehlo.broadcast_in_dim %40, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %44 = stablehlo.broadcast_in_dim %42, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %45 = stablehlo.concatenate %43, %41, %44, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %46 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %47 = stablehlo.broadcast_in_dim %cst_11, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %48 = stablehlo.broadcast_in_dim %33, dims = [0] : (tensor<5xf32>) -> tensor<5x1xf32>
    %49 = stablehlo.broadcast_in_dim %46, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %50 = stablehlo.broadcast_in_dim %47, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %51 = stablehlo.concatenate %49, %50, %48, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x3xf32>
    %52 = stablehlo.broadcast_in_dim %39, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %53 = stablehlo.broadcast_in_dim %45, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %54 = stablehlo.broadcast_in_dim %51, dims = [0, 2] : (tensor<5x3xf32>) -> tensor<5x1x3xf32>
    %55 = stablehlo.concatenate %52, %53, %54, dim = 1 : (tensor<5x1x3xf32>, tensor<5x1x3xf32>, tensor<5x1x3xf32>) -> tensor<5x3x3xf32>
    %56 = stablehlo.iota dim = 0 : tensor<4x4xi32>
    %57 = stablehlo.iota dim = 1 : tensor<4x4xi32>
    %c = stablehlo.constant dense<0> : tensor<i32>
    %58 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<4x4xi32>
    %59 = stablehlo.add %56, %58 : tensor<4x4xi32>
    %60 = stablehlo.compare  EQ, %59, %57,  SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1>
    %61 = stablehlo.convert %60 : (tensor<4x4xi1>) -> tensor<4x4xf32>
    %62 = stablehlo.transpose %55, dims = [0, 2, 1] : (tensor<5x3x3xf32>) -> tensor<5x3x3xf32>
    %c_12 = stablehlo.constant dense<0> : tensor<i32>
    %63 = stablehlo.broadcast_in_dim %c_12, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %c_13 = stablehlo.constant dense<0> : tensor<i32>
    %64 = stablehlo.broadcast_in_dim %c_13, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %65 = stablehlo.concatenate %63, %64, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %66 = stablehlo.broadcast_in_dim %61, dims = [1, 2] : (tensor<4x4xf32>) -> tensor<5x4x4xf32>
    %67 = "stablehlo.scatter"(%66, %65, %62) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 1, 2], scatter_dims_to_operand_dims = [1, 2]>, unique_indices = true}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      stablehlo.return %arg2 : tensor<f32>
    }) : (tensor<5x4x4xf32>, tensor<2xi32>, tensor<5x3x3xf32>) -> tensor<5x4x4xf32>
    %68 = stablehlo.broadcast_in_dim %cst, dims = [1] : (tensor<1xf32>) -> tensor<5x1xf32>
    %69 = stablehlo.concatenate %arg0, %68, dim = 1 : (tensor<5x3xf32>, tensor<5x1xf32>) -> tensor<5x4xf32>
    %70 = stablehlo.dot_general %67, %69, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<5x4x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
    return %70, %27 : tensor<5x4xf32>, tensor<5x3x3xf32>
  }
}
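
Note that the scatter at %67 is the lowering of H.at[:3, :3].set(R2.T), and the two shapes in the error (f32[4,5,4] vs f32[5,4,4]) differ only by a permutation of the batch dimension. As a hypothetical workaround sketch (untested against this bug, and not part of the original report), the homogeneous matrix can be assembled by concatenation so that no scatter is emitted:

def fn_workaround(x):
    # Same math as fn, but H is built from blocks instead of .at[].set(),
    # avoiding the scatter; R1 is omitted for brevity.
    R2 = jnp.array([[x[0], 0, 0],
                    [0, x[1], 0],
                    [0, 0, x[2]]])
    top = jnp.concatenate([R2.T, jnp.zeros((3, 1))], axis=1)  # (3, 4)
    bottom = jnp.array([[0.0, 0.0, 0.0, 1.0]])                # (1, 4)
    H = jnp.concatenate([top, bottom], axis=0)                # (4, 4)
    return H @ jnp.concatenate([x, jnp.array([1.0])])

jax.jit(jax.vmap(fn_workaround)).lower(x_v).compile()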