diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 6944fc8535af..ae880190ad46 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2057,7 +2057,12 @@ def cumsum( return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name) -def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"): +def multinomial_from_uniform( + prob: Tensor, + uniform_sample: Tensor, + sample_indices: Optional[Tensor] = None, + dtype: str = "int64", +): """Returns a tensor where each row contains the index sampled from the multinomial probability distribution located in the corresponding row of tensor prob. @@ -2075,13 +2080,25 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = The sum of values in each row is 1, forming a valid distribution. uniform_sample : Tensor - The uniformly sampled 2-D tensor with the shape (batch, 1). + The uniformly sampled 2-D tensor with the shape (n, 1). Values range from 0 to 1, indicating probabilities sampled uniformly. + sample_indices : Optional[Tensor] + The 2-D tensor with the shape [n, 1], which indicates the specific + probability distribution to sample from. The value of sample_indices[i] + determines that the ith token should be sampled from the sample_indices[i]th + probability distribution. For instance, if there are 3 distinct probability + distributions and the requirement is to sample 2, 3, and 4 tokens from each, + then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2]. + + dtype : str + The data type of output tensor. + + Returns ------- result : Tensor - The computed tensor with shape (batch, 1). + The computed tensor with shape (n, 1). Examples -------- @@ -2089,29 +2106,52 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]] usample = [[0.4], [0.9]] + sample_indices = [[0], [1]] multinomial_from_uniform(prob, usample) -> [[1], [2]] + multinomial_from_uniform(prob, usample, sample_indices) + -> [[1], [2]] """ prob_dtype = prob.dtype sample_dtype = uniform_sample.dtype - batch = prob.shape[0] + out_batch = uniform_sample.shape[0] + + if sample_indices is not None: + assert ( + sample_indices.shape == uniform_sample.shape + ), "The shape of sample_indices must match the shape of uniform_sample." + else: + assert ( + prob.shape[0] == uniform_sample.shape[0] + ), "Number of samples must match the number of probability distributions." + sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1)) + + sample_indices_dtype = sample_indices.dtype @T.prim_func(private=True) - def _get_sample_index(A: T.handle, B: T.handle, C: T.handle): + def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): batch, vocab_size = T.int64(), T.int64() prob = T.match_buffer(A, (batch, vocab_size), prob_dtype) - usample = T.match_buffer(B, (batch, 1), sample_dtype) - output_index = T.match_buffer(C, (batch, 1), dtype) + out_batch = T.int64() + usample = T.match_buffer(B, (out_batch, 1), sample_dtype) + sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype) + output_index = T.match_buffer(D, (out_batch, 1), dtype) - for ax0, ax1 in T.grid(batch, vocab_size): + for ax0, ax1 in T.grid(out_batch, vocab_size): with T.block("T_get_sample_index"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.writes(output_index[v_ax0, 0]) - if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size: + if ( + usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] + or v_ax1 + 1 == vocab_size + ): if v_ax1 == 0: output_index[v_ax0, 0] = 0 - elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]: + elif ( + usample[v_ax0, T.int64(0)] + >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1] + ): output_index[v_ax0, 0] = v_ax1 cumsum_prob = cumsum(prob, axis=1, exclusive=False) @@ -2119,13 +2159,18 @@ def _get_sample_index(A: T.handle, B: T.handle, C: T.handle): return tensor_ir_op( _get_sample_index, "get_sample_index", - args=[cumsum_prob, uniform_sample], - out=Tensor.placeholder([batch, 1], dtype), + args=[cumsum_prob, uniform_sample, sample_indices], + out=Tensor.placeholder([out_batch, 1], dtype), ) def sample_top_p_top_k_from_sorted_prob( - sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor + sorted_prob: Tensor, + sorted_index: Tensor, + top_p: Tensor, + top_k: Tensor, + uniform_sample: Tensor, + sample_indices: Optional[Tensor] = None, ): """Samples indices from a sorted probability tensor based on top_p and top_k criteria. @@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob( to consider for top-k sampling. uniform_sample : Tensor - Uniformly sampled values with shape (batch, 1) are used to select the output indices. + Uniformly sampled values with shape (n, 1) are used to select the output indices. + + sample_indices : Optional[Tensor] + The 2-D tensor with the shape [n, 1], which indicates the specific + probability distribution to sample from. The value of sample_indices[i] + determines that the ith token should be sampled from the sample_indices[i]th + probability distribution. For instance, if there are 3 distinct probability + distributions and the requirement is to sample 2, 3, and 4 tokens from each, + then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2]. Returns ------- result : Tensor - The selected indices with shape (batch, 1). + The selected indices with shape (n, 1). Examples -------- @@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob( top_p = [[0.6],[0.9]] top_k = [[3],[2]] uniform_sample = [[0.5], [0.6]] + sample_indices = [[0], [1]] sample_top_p_top_k_from_sorted_prob( - sorted_prob, sorted_index,top_p, top_k, uniform_sample) + sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices) -> [2, 0] """ prob_dtype = sorted_prob.dtype index_dtype = sorted_index.dtype - batch = sorted_prob.shape[0] + prob_batch = sorted_prob.shape[0] + out_batch = uniform_sample.shape[0] + + if sample_indices is not None: + assert ( + sample_indices.shape == uniform_sample.shape + ), "The shape of sample_indices must match the shape of uniform_sample." + else: + assert ( + sorted_prob.shape[0] == uniform_sample.shape[0] + ), "Number of samples must match the number of probability distributions." + sample_indices = Tensor.from_const( + np.arange(out_batch).reshape(out_batch, 1).astype(np.int64) + ) + print("sample_indices: ", sample_indices) + sample_indices_dtype = sample_indices.dtype def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j): return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]) @@ -2204,27 +2273,34 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1] @T.prim_func(private=True) - def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + def _get_index_from_sorted( + A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle + ): batch, vocab_size = T.int64(), T.int64() + out_batch = T.int64() cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype) - renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype) - usample = T.match_buffer(C, (batch, 1), prob_dtype) - indices = T.match_buffer(D, (batch, vocab_size), index_dtype) - output_index = T.match_buffer(E, (batch, 1), index_dtype) + indices = T.match_buffer(B, (batch, vocab_size), index_dtype) + renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype) + usample = T.match_buffer(D, (out_batch, 1), prob_dtype) + sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype) + output_index = T.match_buffer(F, (out_batch, 1), index_dtype) - for ax0, ax1 in T.grid(batch, vocab_size): + for ax0, ax1 in T.grid(out_batch, vocab_size): with T.block("T_get_index_from_sorted"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.writes(output_index[v_ax0, 0]) if ( - usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0] + usample[v_ax0, T.int64(0)] + < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] + / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + 1 == vocab_size ): if v_ax1 == 0: output_index[v_ax0, 0] = indices[v_ax0, 0] elif ( usample[v_ax0, T.int64(0)] - >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0] + >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1] + / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] ): output_index[v_ax0, 0] = indices[v_ax0, v_ax1] @@ -2235,7 +2311,7 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E "get_renorm_prob", args=[cumsum_sorted, top_p, top_k], out=Tensor.placeholder( - [batch, 1], + [prob_batch, 1], prob_dtype, ), ) @@ -2243,8 +2319,8 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E out_index_in_sorted = tensor_ir_op( _get_index_from_sorted, "get_index_from_sorted", - args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index], - out=Tensor.placeholder([batch, 1], index_dtype), + args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices], + out=Tensor.placeholder([out_batch, 1], index_dtype), ) return out_index_in_sorted @@ -2293,7 +2369,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T. top_k = T.match_buffer(D, (batch, 1), top_k_dtype) cutoff = T.match_buffer(E, (batch, 1), prob_dtype) for ax0, ax1 in T.grid(batch, vocab_size): - with T.block("T_get_renorm_prob"): + with T.block("T_get_renorm_cutoff"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0: cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0] diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 3457989a551f..0d579163cdd0 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -846,34 +846,36 @@ def test(self): @tvm.testing.requires_gpu def test_multinomial_from_uniform(): - prob_shape = (4, 5) - sample_shape = (4, 1) + prob_shape = (3, 5) + sample_shape = (6, 1) class Model(Module): - def foo(self, prob: Tensor, uniform_sample: Tensor): - z0 = op.multinomial_from_uniform(prob, uniform_sample) + def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: Tensor): + z0 = op.multinomial_from_uniform(prob, uniform_sample, sample_indices) return z0 # fmt: off @I.ir_module class Expected: @T.prim_func(private=True) - def get_sample_index(A: T.handle, B: T.handle, C: T.handle): + def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): batch, vocab_size = T.int64(), T.int64() prob = T.match_buffer(A, (batch, vocab_size)) - usample = T.match_buffer(B, (batch, 1)) - output_index = T.match_buffer(C, (batch, 1), "int64") + out_batch = T.int64() + usample = T.match_buffer(B, (out_batch, 1)) + sample_indices = T.match_buffer(C, (out_batch, 1), "int64") + output_index = T.match_buffer(D, (out_batch, 1), "int64") # with T.block("root"): - for ax0, ax1 in T.grid(batch, vocab_size): + for ax0, ax1 in T.grid(out_batch, vocab_size): with T.block("T_get_sample_index"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)]) + T.reads(usample[v_ax0, T.int64(0)], prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)]) T.writes(output_index[v_ax0, 0]) - if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + T.int64(1) == vocab_size: + if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size: if v_ax1 == T.int64(0): output_index[v_ax0, 0] = T.int64(0) else: - if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - T.int64(1)]: + if usample[v_ax0, T.int64(0)] >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)]: output_index[v_ax0, 0] = v_ax1 @R.function @@ -886,13 +888,13 @@ def _initialize_effect() -> R.Tuple(R.Object): return gv @R.function - def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)): - R.func_attr({"num_input": 3}) + def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)): + R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=False) - lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64")) - gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,) + cumsum: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=0) + lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) + gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,) R.output(gv1) return gv1 # fmt: on @@ -903,6 +905,7 @@ def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1) "foo": { "prob": spec.Tensor(prob_shape, "float32"), "uniform_sample": spec.Tensor(sample_shape, "float32"), + "sample_indices": spec.Tensor(sample_shape, "int64"), } }, debug=True, @@ -924,62 +927,59 @@ def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1) np_prob = np_rand / np_rand.sum(axis=1, keepdims=True) nd_prob = tvm.nd.array(np_prob, dev) # special sample to get deterministic results - nd_sample = tvm.nd.array(np.array([[1], [0], [0], [1]]).astype(np.float32), dev) - inputs = [nd_prob, nd_sample, effects] + nd_sample = tvm.nd.array(np.array([[1], [0], [1], [1], [0], [1]]).astype(np.float32), dev) + nd_sample_indices = tvm.nd.array(np.array([[0], [1], [1], [2], [2], [2]]).astype(np.int64), dev) + inputs = [nd_prob, nd_sample, nd_sample_indices, effects] res = vm["foo"](*inputs) - tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0], [4]]).astype(np.int64)) + tvm.testing.assert_allclose( + res[0].numpy(), np.array([[4], [0], [4], [4], [0], [4]]).astype(np.int64) + ) @tvm.testing.requires_gpu def test_sample_top_p_top_k_from_sorted_prob(): prob_shape = (2, 3) - sample_shape = (2, 1) + sample_shape = (3, 1) class Model(Module): def foo( - self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor + self, + prob: Tensor, + index: Tensor, + top_p: Tensor, + top_k: Tensor, + uniform_sample: Tensor, + sample_indices: Tensor, ): - z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p, top_k, uniform_sample) + z0 = op.sample_top_p_top_k_from_sorted_prob( + prob, index, top_p, top_k, uniform_sample, sample_indices + ) return z0 # fmt: off @I.ir_module class Expected: @T.prim_func(private=True) - def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle): batch, vocab_size = T.int64(), T.int64() cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) - renorm_prob = T.match_buffer(B, (batch, 1)) - usample = T.match_buffer(C, (batch, 1)) - indices = T.match_buffer(D, (batch, vocab_size), "int64") - output_index = T.match_buffer(E, (batch, 1), "int64") + indices = T.match_buffer(B, (batch, vocab_size), "int64") + renorm_prob = T.match_buffer(C, (batch, 1)) + out_batch = T.int64() + usample = T.match_buffer(D, (out_batch, 1)) + sample_indices = T.match_buffer(E, (out_batch, 1), "int64") + output_index = T.match_buffer(F, (out_batch, 1), "int64") # with T.block("root"): - for ax0, ax1 in T.grid(batch, vocab_size): + for ax0, ax1 in T.grid(out_batch, vocab_size): with T.block("T_get_index_from_sorted"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads( - usample[v_ax0, T.int64(0)], - cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 - T.int64(1) + T.int64(2)], - renorm_prob[v_ax0, 0], - indices[ - v_ax0, - T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1) - + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1)), - ], - ) + T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[v_ax0, T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))]) T.writes(output_index[v_ax0, 0]) - if ( - usample[v_ax0, T.int64(0)] - < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0] - or v_ax1 + T.int64(1) == vocab_size - ): + if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size: if v_ax1 == T.int64(0): output_index[v_ax0, 0] = indices[v_ax0, 0] else: - if ( - usample[v_ax0, T.int64(0)] - >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] / renorm_prob[v_ax0, 0] - ): + if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]: output_index[v_ax0, 0] = indices[v_ax0, v_ax1] @T.prim_func(private=True) @@ -1015,21 +1015,14 @@ def _initialize_effect() -> R.Tuple(R.Object): return gv @R.function - def foo( - prob: R.Tensor((2, 3), dtype="float32"), - index: R.Tensor((2, 3), dtype="int64"), - top_p: R.Tensor((2, 1), dtype="float32"), - top_k: R.Tensor((2, 1), dtype="int64"), - uniform_sample: R.Tensor((2, 1), dtype="float32"), - _io: R.Object, - ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)): - R.func_attr({"num_input": 6}) + def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype="int64"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), uniform_sample: R.Tensor((3, 1), dtype="float32"), sample_indices: R.Tensor((3, 1), dtype="int64"), _io: R.Object,) -> R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)): + R.func_attr({"num_input": 7}) cls = Expected with R.dataflow(): cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=None) lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32")) - lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1, uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64")) - gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,) + lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, lv1, uniform_sample, sample_indices), out_sinfo=R.Tensor((3, 1), dtype="int64")) + gv1: R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,) R.output(gv1) return gv1 # fmt: on @@ -1040,9 +1033,10 @@ def foo( "foo": { "prob": spec.Tensor(prob_shape, "float32"), "index": spec.Tensor(prob_shape, "int64"), - "top_p": spec.Tensor(sample_shape, "float32"), - "top_k": spec.Tensor(sample_shape, "int64"), + "top_p": spec.Tensor((prob_shape[0], 1), "float32"), + "top_k": spec.Tensor((prob_shape[0], 1), "int64"), "uniform_sample": spec.Tensor(sample_shape, "float32"), + "sample_indices": spec.Tensor(sample_shape, "int64"), } }, debug=True, @@ -1063,12 +1057,13 @@ def foo( indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev) top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev) top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev) - usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev) + usample = tvm.nd.array(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), dev) + sample_indices = tvm.nd.array(np.array([[0], [1], [1]]).astype(np.int64), dev) - inputs = [sorted_prob, indices, top_p, top_k, usample, effects] + inputs = [sorted_prob, indices, top_p, top_k, usample, sample_indices, effects] res = vm["foo"](*inputs) - tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0]]).astype(np.int64)) + tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0], [0]]).astype(np.int64)) @tvm.testing.requires_gpu