Skip to content

Commit

Permalink
Fix sparse.expand_single_dim function
Browse files Browse the repository at this point in the history
To be honest I'm not sure how I've fixed this...
I had some issues with building a tf.function that suddenly stopped.
  • Loading branch information
qmarcou committed Feb 14, 2024
1 parent 8dab1ed commit e376bf8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 15 additions & 1 deletion keras_utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,23 @@ def expand_single_dim(sp_tensor: tf.sparse.SparseTensor,
times: int,
axis: int):
out_tensor = sp_tensor
for j in tf.range(start=0, limit=times - 1, delta=1):
# sp_tensor_shape = sp_tensor.get_shape()
# tensor_shape_invariant : tf.TensorShape = tf.TensorShape.concatenate(
# sp_tensor_shape[0:axis],
# tf.TensorShape([None])).concatenate(
# sp_tensor_shape[axis+1:],)
j=tf.constant(0)
while j<(times-1):
#for j in tf.range(start=0, limit=times - 1, delta=1):
# tf.autograph.experimental.set_loop_options(
# shape_invariants=[(out_tensor, tf.TensorShape([None]))]
# )
# tf.autograph.experimental.set_loop_options(
# shape_invariants=[(out_tensor, tensor_shape_invariant)]
# )
out_tensor = tf.sparse.concat(axis=axis,
sp_inputs=[out_tensor, sp_tensor])
j+=1
return out_tensor


Expand Down
2 changes: 1 addition & 1 deletion keras_utils/tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_expend_single_unit_dim(self):
tf.ones(2, dtype=tf.float32),
[2, 2])
sp_t = tf.sparse.expand_dims(sp_t, axis=0)
sp_t_exp = sparse.expand_single_dim(sp_t,
sp_t_exp = sparse.expand_single_dim(sp_tensor=sp_t,
times=2,
axis=0)
exp_out = np.array([[[1., 0.], [0., 1.]],
Expand Down

0 comments on commit e376bf8

Please sign in to comment.