diff --git a/export.py b/export.py index c55b62282..3099d80c9 100644 --- a/export.py +++ b/export.py @@ -46,10 +46,10 @@ def __init__(self, model, device, max_seq_length=1024): self.model = model # init model here if necessary - def forward(self, x, input_pos): + def forward(self, idx, input_pos): # input_pos: [B, 1] assert input_pos.shape[-1] == 1 - logits = self.model(x, input_pos) + logits = self.model(idx, input_pos) return logits # sample(logits, **sampling_kwargs)