hungarian_callback.py
import numpy as np
import jax
from jax import numpy as jnp
from jax import vmap
from jax.experimental import host_callback
from jax.nn import one_hot
from scipy.optimize import linear_sum_assignment

from c_modules.mine import compute_parallel

def multiple_linear_assignment(X):
    # X has shape (batch_size, k, k); returns the assigned column index
    # for each row of each cost matrix, shape (batch_size, k).
    X_out = np.zeros(X.shape[:-1], dtype=int)
    for i, x in enumerate(X):
        X_out[i] = linear_sum_assignment(x)[1]
    return X_out

def parlinear_assignment(X):
    # Alternative host-side solver from the compiled c_modules extension.
    return compute_parallel(X)

# def hungarian_callback_loop(X):
#     # This will send a batch of matrices of size (n, k, k)
#     # to the host and returns an (n, k) set of indices
#     # of the locations of assignments in the rows of the matrices
#     return host_callback.call(
#         parlinear_assignment,
#         X,
#         result_shape=jax.ShapeDtypeStruct(X.shape[:-1], jnp.int32),
#     )

def hungarian_callback_loop(X):
    # Sends a batch of matrices of shape (n, k, k) to the host and returns an
    # (n, k) array of indices: the assigned column for each row of each matrix.
    return host_callback.call(
        multiple_linear_assignment,
        # parlinear_assignment,
        X,
        result_shape=jax.ShapeDtypeStruct(X.shape[:-1], jnp.int64),
    )

# def batched_hungarian(X):
#     # Input is (n, k, k)
#     n = X.shape[-1]
#     indices = hungarian_callback_loop(X)
#     return vmap(one_hot, in_axes=(0, None))(indices, n)

def batched_hungarian(X):
    # Input is (n, k, k); output is (n, k, k), one permutation matrix per input.
    n = X.shape[-1]
    # indices = hungarian_callback(X)
    indices = hungarian_callback_loop(X)
    return vmap(one_hot, in_axes=(0, None))(indices, n)

def linear_sum_assignment_wrapper(X):
    # Host-side helper: return only the column indices of the optimal assignment.
    return linear_sum_assignment(X)[1]

def hungarian(X):
    # Single-matrix version; the batched version above is generally more efficient.
    indices = host_callback.call(
        linear_sum_assignment_wrapper,
        X,
        result_shape=jax.ShapeDtypeStruct(X.shape[:-1], int),
    )
    return one_hot(indices, X.shape[0])
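
# A minimal usage sketch, not part of the original module. It assumes a JAX
# version that still ships jax.experimental.host_callback, and it enables x64
# so that the int64 result dtype declared in hungarian_callback_loop matches
# what the host returns. It checks that batched_hungarian produces one
# permutation matrix per cost matrix and that it runs under jit, which is the
# point of routing scipy's solver through host_callback.
if __name__ == "__main__":
    jax.config.update("jax_enable_x64", True)  # assumption: x64 mode for the int64 result

    key = jax.random.PRNGKey(0)
    costs = jax.random.uniform(key, (4, 5, 5))  # batch of 4 cost matrices, each 5 x 5

    perms = jax.jit(batched_hungarian)(costs)  # shape (4, 5, 5); each row is one-hot
    print(perms.shape)

    # Every output should be a valid permutation matrix: rows and columns sum to 1.
    rows_ok = jnp.allclose(perms.sum(axis=-1), 1.0)
    cols_ok = jnp.allclose(perms.sum(axis=-2), 1.0)
    print(bool(rows_ok) and bool(cols_ok))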