-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
25 lines (22 loc) · 1005 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoConfig, AutoModelForTokenClassification
class JointATISModel(nn.Module, PyTorchModelHubMixin):
    """Joint slot-filling (token classification) and intent-classification model.

    A Hugging Face ``AutoModelForTokenClassification`` backbone produces the
    per-token ("dst") logits/loss. A linear ``intent_head`` is applied to the
    position-0 hidden state of the backbone's last layer to produce intent
    logits.

    Args:
        model_name: Hugging Face model id or local path of the backbone.
        num_labels: Number of token (slot) labels for the backbone's head.
        num_intents: Number of intent classes for ``intent_head``.
    """

    def __init__(self, model_name, num_labels, num_intents):
        super().__init__()
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        # Reuse the config the backbone was loaded with instead of fetching it
        # a second time via AutoConfig; this also guarantees hidden_size
        # matches the weights actually held by ``self.model``.
        self.model_config = self.model.config
        self.intent_head = nn.Linear(self.model_config.hidden_size, num_intents)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone once and compute both task heads.

        Args:
            input_ids: Token ids for the backbone (presumably shape
                (batch, seq_len) — TODO confirm against the data pipeline).
            attention_mask: Attention mask for the backbone.
            labels: Optional token labels forwarded to the backbone. When
                omitted, the backbone computes no loss (per the transformers
                convention ``dst_loss`` is then ``None``).

        Returns:
            dict with:
            - ``"dst_logits"``: per-token logits from the backbone.
            - ``"intent_logits"``: logits from ``intent_head``; last dimension
              is ``num_intents``.
            - ``"intent_loss"``: the SAME tensor as ``"intent_logits"``. It is
              kept only for backward compatibility with existing callers and,
              despite its name, is not a loss.
            - ``"dst_loss"``: the backbone's token-classification loss.
        """
        # Pass attention_mask by keyword: the second positional parameter of a
        # token-classification backbone's forward is not attention_mask for
        # every architecture (e.g. GPT-2 takes past_key_values there).
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        # Last layer, position 0 of dim 1 — the [CLS]-style token for
        # BERT-like backbones (confirm this pooling for other architectures).
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        intent_logits = self.intent_head(pooled_output)
        return {
            "dst_logits": outputs.logits,
            "intent_logits": intent_logits,
            # Deprecated alias: historically this key carried the logits.
            "intent_loss": intent_logits,
            "dst_loss": outputs.loss,
        }