code issue #9

Open · wgkwgk opened this issue Apr 9, 2021 · 1 comment

wgkwgk commented Apr 9, 2021

It doesn't work when batch_size is greater than 1. I found that differing lengths between samples in a minibatch cause an error.
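
For what it's worth, the failure looks like torch.LongTensor being handed a ragged nested list: the padding equalizes sequence lengths within each sample, but not the number of sequences across samples. A minimal sketch of the suspected failure mode (the values are made up):

import torch

sample_a = [[1, 2, 0], [3, 4, 0]]  # two padded candidate sequences
sample_b = [[5, 6, 0]]             # only one candidate sequence
# Raises a ValueError ("expected sequence of length 2 at dim 1 (got 1)")
# because the nested list is ragged at dim 1.
torch.LongTensor([sample_a, sample_b])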

The code of that function is below:

def collate_dialog(self, batch):
        '''
        Padding and Collating
        '''
        # Equalize the number of persona sentences across samples by appending empty lists.
        max_num_persona = max(len(b['persona']) for b in batch)
        for b in batch:
            b['persona'] += [[] for _ in range(max_num_persona - len(b['persona']))]
        max_seq_len = max(len(c) for b in batch for c in b['input_ids'])
        max_persona_len = max(len(p) for b in batch for p in b['persona'])
        max_history_len = max(len(b['history']) for b in batch)
        padded_batch = {}

        for name in batch[0].keys():
            if name in ['input_ids', 'token_type_ids']:
                # Pad each candidate sequence to max_seq_len, then reshape to
                # (batch, max_num_persona, num_candidates, max_seq_len).
                padded = torch.LongTensor(
                    [[c + [self.pad_id]*(max_seq_len - len(c)) for c in sample[name]] for sample in batch])

                padded_batch[name] = padded.view((-1, max_num_persona, self.num_candidates) + padded.shape[2:])
            elif name == 'persona':
                padded_batch[name] = torch.LongTensor(
                    [[p + [self.pad_id]*(max_persona_len - len(p)) for p in sample['persona']] for sample in batch])
            elif name == 'history':
                padded_batch[name] = torch.LongTensor(
                    [sample[name] + [self.pad_id]*(max_history_len - len(sample[name])) for sample in batch])
            elif name == "lm_labels":
                padded = torch.LongTensor(
                    [[c + [-100]*(max_seq_len - len(c)) for c in sample[name]] for sample in batch])
                padded_batch[name] = padded.view((-1, max_num_persona, self.num_candidates) + padded.shape[2:])
            elif name == "mc_token_ids":
                padded_batch[name] = torch.LongTensor([sample[name] for sample in batch])\
                    .view((-1, max_num_persona, self.num_candidates))
            elif name in ["mc_labels", "effects"]:
                padded_batch[name] = torch.LongTensor([sample[name] for sample in batch])
            else:
                assert False, f"Unexpected batch element with key '{name}'"
        # print("PersonaChatDataset.collate_dialog:")
        # for k, v in padded_batch.items():
        #     print(f"{k}.shape:", v.shape)
        return padded_batch
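
One possible edit (an untested sketch, assuming every sample really holds self.num_candidates candidates per persona sentence): after the persona-padding loop, also equalize the per-candidate fields so every sample contributes max_num_persona * num_candidates sequences before any tensor is built:

        # Untested sketch: make the nested lists rectangular for torch.LongTensor
        # by topping each sample up with dummy candidate rows.
        max_num_cands = max_num_persona * self.num_candidates
        for b in batch:
            n_missing = max_num_cands - len(b['input_ids'])
            b['input_ids'] += [[self.pad_id] for _ in range(n_missing)]      # padded to max_seq_len later
            b['token_type_ids'] += [[self.pad_id] for _ in range(n_missing)]
            b['lm_labels'] += [[-100] for _ in range(n_missing)]             # ignored by the LM loss
            b['mc_token_ids'] += [0] * n_missing                             # point at position 0 of the dummy rows

Whether 'mc_labels' and 'effects' need the same treatment depends on their per-sample layout, which I can't tell from this function alone.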

wgkwgk commented Apr 14, 2021

Could you solve that problem by editing that code?
