Skip to content

Commit

Permalink
Update models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod authored Nov 2, 2023
1 parent f89bbce commit 68793e6
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,14 @@ class default:
save_steps = 1000000
label_names = ["labels"]
include_inputs_for_metrics = True
model_name = "sileod/deberta-v3-base-tasksource-nli"

default, hparams = to_dict(default), to_dict(hparams)
self.p = hparams.get('p', 1)
self.num_proc = hparams.get('num_proc',None)
self.batched = hparams.get('batched',False)

trainer_args = transformers.TrainingArguments(
**fc.project({**default,**hparams}, dir(transformers.TrainingArguments))
**{**default, **fc.project(hparams, dir(transformers.TrainingArguments))},
)
if not tokenizer:
tokenizer = AutoTokenizer.from_pretrained(hparams["model_name"])
Expand Down Expand Up @@ -383,7 +382,7 @@ class default:
task: dataset["test"]
for task, dataset in self.processed_tasks.items()
}
# We preventstrainer from automatically evaluating on each dataset:
# We prevent Trainer from automatically evaluating on each dataset:
# transformers.Trainer recognizes eval_dataset instances of "dict"
# But we use a custom "evaluate" function so that we can use different metrics for each task
self.eval_dataset = MappingProxyType(self.eval_dataset)
Expand Down

0 comments on commit 68793e6

Please sign in to comment.