You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
It doesn't work when batch_size is greater than 1.
I found that differing lengths between samples in a minibatch cause an error.
The code of that function is below:
def collate_dialog(self, batch):
    """Pad and collate a list of dialog samples into batched LongTensors.

    Samples may differ in persona count, per-sequence token length, and
    history length; every field is padded to the per-batch maximum so the
    nested lists are rectangular and batch_size > 1 works.

    Args:
        batch: list of sample dicts with keys 'input_ids', 'token_type_ids',
            'lm_labels', 'persona', 'history', 'mc_token_ids', 'mc_labels'
            and optionally 'effects'. Each sample is assumed to hold
            num_persona * self.num_candidates candidate sequences in
            'input_ids'/'token_type_ids'/'lm_labels' (implied by the
            .view() below) -- TODO confirm against __getitem__.

    Returns:
        dict of torch.LongTensor; 'input_ids', 'token_type_ids' and
        'lm_labels' have shape (batch, max_num_persona, num_candidates,
        max_seq_len), 'mc_token_ids' has shape (batch, max_num_persona,
        num_candidates).

    Raises:
        ValueError: if a sample contains an unexpected key.
    """
    max_num_persona = max(len(b['persona']) for b in batch)
    num_cands = self.num_candidates
    # Samples with fewer personas also carry fewer candidate sequences.
    # Pad *both* the persona list and every per-candidate field up to the
    # batch-wide maximum; otherwise torch.LongTensor receives ragged
    # nested lists and collation fails for batch_size > 1 (reported bug).
    for b in batch:
        n_missing = max_num_persona - len(b['persona'])
        b['persona'] += [[] for _ in range(n_missing)]
        for key in ('input_ids', 'token_type_ids', 'lm_labels'):
            if key in b:
                b[key] += [[] for _ in range(n_missing * num_cands)]
        if 'mc_token_ids' in b:
            # Dummy candidates point at position 0; their lm_labels are
            # all -100, so they contribute nothing to the LM loss.
            b['mc_token_ids'] += [0] * (n_missing * num_cands)

    max_seq_len = max(len(c) for b in batch for c in b['input_ids'])
    max_persona_len = max(len(p) for b in batch for p in b['persona'])
    max_history_len = max(len(b['history']) for b in batch)

    padded_batch = {}
    for name in batch[0].keys():
        if name in ('input_ids', 'token_type_ids'):
            padded = torch.LongTensor(
                [[c + [self.pad_id] * (max_seq_len - len(c)) for c in sample[name]]
                 for sample in batch])
            padded_batch[name] = padded.view(
                (-1, max_num_persona, num_cands) + padded.shape[2:])
        elif name == 'persona':
            padded_batch[name] = torch.LongTensor(
                [[p + [self.pad_id] * (max_persona_len - len(p)) for p in sample['persona']]
                 for sample in batch])
        elif name == 'history':
            padded_batch[name] = torch.LongTensor(
                [sample[name] + [self.pad_id] * (max_history_len - len(sample[name]))
                 for sample in batch])
        elif name == 'lm_labels':
            # -100 is the conventional ignore_index for the LM loss.
            padded = torch.LongTensor(
                [[c + [-100] * (max_seq_len - len(c)) for c in sample[name]]
                 for sample in batch])
            padded_batch[name] = padded.view(
                (-1, max_num_persona, num_cands) + padded.shape[2:])
        elif name == 'mc_token_ids':
            padded_batch[name] = torch.LongTensor(
                [sample[name] for sample in batch]).view(-1, max_num_persona, num_cands)
        elif name in ('mc_labels', 'effects'):
            padded_batch[name] = torch.LongTensor([sample[name] for sample in batch])
        else:
            # Raise instead of `assert False`: asserts vanish under -O.
            raise ValueError(f"Unexpected batch element with key '{name}'")
    return padded_batch
The text was updated successfully, but these errors were encountered:
It doesn't work when batch_size is greater than 1.
I found that differing lengths between samples in a minibatch cause an error.
The code of that function is below.
The text was updated successfully, but these errors were encountered: