Skip to content

Commit

Permalink
Fix "dataset length is unknown".
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560052846
  • Loading branch information
fineguy authored and t5-copybara committed Aug 28, 2023
1 parent 828e910 commit 22ff057
Showing 1 changed file with 12 additions and 33 deletions.
45 changes: 12 additions & 33 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,33 +1586,18 @@ def infer_fn(
'This could be an indication that the dataset is nondeterministic.'
),
)
try:
original_ds_length = len(ds)
dataset_remainder = original_ds_length % batch_size # pytype:disable=wrong-arg-types
logging.info('length of dataset = %s', len(ds))
except TypeError as e:
if str(e).endswith('dataset length is unknown.'):
logging.warning(
'The following error is likely due to the use of TensorFlow v1 in '
'your dataset pipeline. Verify you are not importing from '
'`tf.compat.v1` as part of your pipeline.'
)
raise e

if dataset_remainder:
dataset_pad_amt = batch_size - dataset_remainder
logging.info(
'Padding infer dataset with %d examples for even per-replica shards.',
dataset_pad_amt,
)
# Pad with the first example using an index of -1 so seqio will ignore.
pad_ds = (
ds.take(1)
.map(lambda i, x: (np.int64(-1), x))
.cache()
.repeat(dataset_pad_amt)
)
ds = ds.concatenate(pad_ds)
logging.info(
'Padding infer dataset with %d examples for even per-replica shards.',
batch_size - 1,
)
# Pad with the first example using an index of -1 so seqio will ignore.
pad_ds = (
ds.take(1)
.map(lambda i, x: (np.int64(-1), x))
.cache()
.repeat(batch_size - 1)
)
ds = ds.concatenate(pad_ds)

# Shard the infer dataset across replica sets.
sharded_ds = ds.shard(num_shards, shard_id).batch(
Expand Down Expand Up @@ -1750,12 +1735,6 @@ def _copy_to_host_async(x):
indices_and_outputs = jax.tree_map(
lambda x: np.array(x).tolist(), indices_and_outputs
)
if len(indices_and_outputs) != original_ds_length:
raise ValueError(
'Size of indices_and_outputs does not match length of original '
'dataset: %d versus %d'
% (len(indices_and_outputs), original_ds_length)
)

if aux_values is None:
return indices_and_outputs
Expand Down

0 comments on commit 22ff057

Please sign in to comment.