Merge pull request #74 from boun-tabi-LMG/zeynepyirmibes-eval-patch
Adding evaluation of TURNA-encoder
zeynepyirmibes authored May 5, 2024
2 parents 5bcf918 + 9df38db commit a0e954b
Showing 5 changed files with 56 additions and 12 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -128,7 +128,9 @@ dataset_processor = DatasetProcessor(
 test_dataset = dataset_processor.load_and_preprocess_data(split="test")
 
 test_params = {
-    'per_device_eval_batch_size': 4
+    'per_device_eval_batch_size': 4,
+    'output_dir': './',
+    'predict_with_generate': True
 }
 
 model_path = "turna_tr_news_summarization"
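A minimal continuation of the README snippet above, showing how the new test_params fields reach the evaluator. This sketch is not part of the commit: the import path, the tokenizer path, the length limits, the generation_params values, the postprocess_fn choice and the evaluate_model call are assumptions based on the surrounding README and experiments/eval.py.

from turkish_lm_tuner import EvaluatorForConditionalGeneration   # assumed import path

test_params = {
    'per_device_eval_batch_size': 4,
    'output_dir': './',                  # where predictions and metrics are written
    'predict_with_generate': True        # generate summaries during evaluation instead of teacher forcing
}
generation_params = {'max_new_tokens': 128, 'num_beams': 4}      # illustrative values

model_path = "turna_tr_news_summarization"
evaluator = EvaluatorForConditionalGeneration(
    model_path, "boun-tabi-LMG/TURNA", "summarization",           # model, tokenizer, task
    764, 128,                                                     # max_input_length, max_target_length (assumed)
    test_params, generation_params,
    dataset_processor.dataset.postprocess_data)                   # postprocess_fn, mirroring experiments/eval.py (assumed)
results = evaluator.evaluate_model(test_dataset)                  # assumed entry point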
5 changes: 4 additions & 1 deletion experiments/eval.py
@@ -27,6 +27,9 @@ def main(cfg: DictConfig):
     test_params = cfg.test_params
     generation_params = cfg.generation_params
     dataset_location = cfg.dataset_loc
 
+    if "num_labels" in cfg.keys():
+        num_labels = cfg.num_labels
+
     logger.info("Loading test dataset")
     dataset_processor = DatasetProcessor(dataset_name, task, task_format, task_mode, tokenizer_path, max_input_length, max_target_length, dataset_location)
@@ -69,7 +72,7 @@ def main(cfg: DictConfig):
         evaluator = EvaluatorForConditionalGeneration(model_path, tokenizer_path, task, max_input_length, max_target_length, test_params, generation_params, postprocess_fn)
     elif task_format == 'classification':
         logger.info("Evaluating in classification mode")
-        evaluator = EvaluatorForClassification(model_path, tokenizer_path, task, test_params)
+        evaluator = EvaluatorForClassification(model_path, tokenizer_path, task, max_input_length, test_params, num_labels, postprocess_fn)
 
 
     logger.info("Evaluating model")
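Since eval.py now reads num_labels from the config when present, classification configs are expected to carry that field. A sketch of a matching config, built with OmegaConf to match the script's DictConfig interface; only test_params, generation_params, dataset_loc and num_labels are keys actually visible in eval.py, the remaining values are illustrative assumptions.

from omegaconf import OmegaConf, DictConfig

cfg: DictConfig = OmegaConf.create({
    "dataset_loc": "",                     # empty for Hub-hosted datasets (assumed)
    "num_labels": 2,                       # picked up by the new `if "num_labels" in cfg.keys()` branch
    "test_params": {
        "per_device_eval_batch_size": 8,
        "output_dir": "./",
    },
    "generation_params": {},               # unused for classification (assumed)
})

num_labels = cfg.num_labels if "num_labels" in cfg.keys() else None
print(num_labels)   # 2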
31 changes: 28 additions & 3 deletions turkish_lm_tuner/evaluator.py
@@ -2,9 +2,14 @@
     AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification,
     Seq2SeqTrainer, Seq2SeqTrainingArguments,
     Trainer, TrainingArguments,
-    EvalPrediction
+    EvalPrediction,
+    AutoConfig, AutoModel,
+    AutoModelForTokenClassification,
 )
 
+from .t5_classifier import T5ForClassification, T5ForClassificationConfig
+
+
 from .metrics import load_task_metrics
 import pandas as pd
 import numpy as np
@@ -63,11 +68,31 @@ def compute_metrics(self, preds, labels):
 
 class EvaluatorForClassification(BaseEvaluator):
 
+    def __init__(self, model_path, tokenizer_path, task, max_input_length, test_params, num_labels, postprocess_fn=None):
+        super().__init__(model_path, tokenizer_path, task, test_params, postprocess_fn)
+        self.max_input_length = max_input_length
+        self.num_labels = num_labels
+
     def initialize_model(self):
-        # If used without fine-tuning, model should be loaded from the model save path
-        return AutoModelForSequenceClassification.from_pretrained(self.model_path)
+        AutoConfig.register("t5_turna_enc", T5ForClassificationConfig)
+        AutoModel.register(T5ForClassificationConfig, T5ForClassification)
+        config = AutoConfig.from_pretrained(self.model_path)
+
+        if config.model_type in ["t5", "mt5", "t5_turna_enc"]:
+            if self.task == "classification":
+                return T5ForClassification.from_pretrained(self.model_path, config, self.num_labels, "single_label_classification")
+            elif self.task in ["ner", "pos_tagging"]:
+                return T5ForClassification.from_pretrained(self.model_path, config, self.num_labels, "token_classification")
+            else:
+                return T5ForClassification.from_pretrained(self.model_path, config, 1, "regression")
+        else:
+            if self.task == "classification":
+                return AutoModelForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
+            elif self.task in ["ner", "pos_tagging"]:
+                return AutoModelForTokenClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
 
     def initialize_trainer(self, model):
 
         test_args = TrainingArguments(
             **self.test_params)
 
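The core mechanism in the new initialize_model is Transformers' Auto-class registration: once T5ForClassificationConfig is registered under the "t5_turna_enc" model type, checkpoints saved with that config resolve to the custom encoder classifier. A standalone sketch of the same pattern with toy classes; AutoConfig.register and AutoModel.register are real Transformers APIs, everything else here is made up for illustration.

from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class ToyEncoderConfig(PretrainedConfig):
    model_type = "toy_enc"                      # mirrors model_type = "t5_turna_enc"

class ToyEncoder(PreTrainedModel):
    config_class = ToyEncoderConfig
    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(8, 2)

# Register the custom config/model pair with the Auto classes.
AutoConfig.register("toy_enc", ToyEncoderConfig)
AutoModel.register(ToyEncoderConfig, ToyEncoder)

# A checkpoint directory whose config.json contains "model_type": "toy_enc"
# now loads through the generic entry points:
#   config = AutoConfig.from_pretrained(checkpoint_dir)   -> ToyEncoderConfig
#   model  = AutoModel.from_pretrained(checkpoint_dir)     -> ToyEncoder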
12 changes: 10 additions & 2 deletions turkish_lm_tuner/t5_classifier.py
@@ -1,11 +1,14 @@
 from torch import nn
-from transformers import T5EncoderModel
+from transformers import T5EncoderModel, PretrainedConfig
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.models.t5.modeling_t5 import T5PreTrainedModel
 
+class T5ForClassificationConfig(PretrainedConfig):
+    model_type="t5_turna_enc"
+
 class T5ForClassification(T5PreTrainedModel): # nn.Module
+    config_class = T5ForClassificationConfig
     """
     T5 encoder adapted for classification
     Args:
@@ -18,7 +21,12 @@ class T5ForClassification(T5PreTrainedModel): # nn.Module
     def __init__(self, pretrained_model_name, config, num_labels, problem_type, dropout_prob=0.1):
         super().__init__(config)
 
-        self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
+        try:
+            self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
+        except Exception as e:
+            pretrained_model_name = config._name_or_path
+            self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
+
         self.dropout = nn.Dropout(dropout_prob)
         self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)
         self.config = self.encoder.config
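The try/except added to __init__ simply retries the encoder load from config._name_or_path when loading from pretrained_model_name fails. A hedged construction sketch following the constructor signature visible in this diff; the checkpoint name is an assumption.

from transformers import AutoConfig
from turkish_lm_tuner.t5_classifier import T5ForClassification

base_model = "boun-tabi-LMG/TURNA"                 # assumed encoder checkpoint
config = AutoConfig.from_pretrained(base_model)

# num_labels and problem_type follow the constructor signature shown above
model = T5ForClassification(base_model, config, num_labels=2,
                            problem_type="single_label_classification")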
16 changes: 11 additions & 5 deletions turkish_lm_tuner/tr_datasets.py
@@ -156,7 +156,11 @@ def convert_sts_label(label):
            try:
                return(float(label.strip()))
            except:
-                return 0
+                try:
+                    return(float(label))
+                except:
+                    return 0
+
        return [convert_sts_label(ex) for ex in examples]
 
 class NLI_TRDataset(BaseDataset):
@@ -660,7 +664,12 @@ def __init__(self, dataset_name=None):
         self.OUT_LABEL_DICT = {v: k for k, v in self.IN_LABEL_DICT.items()}
 
     def postprocess_data(self, examples):
-        return [self.OUT_LABEL_DICT.get(ex.strip(), -1) for ex in examples]
+        def convert_class_label(label):
+            if type(label) == type(""):
+                return self.OUT_LABEL_DICT.get(label.strip(), -1)
+            else:
+                return label
+        return [convert_class_label(ex) for ex in examples]
 
     def load_dataset(self, split=None):
         return super().load_dataset(split)
@@ -721,9 +730,6 @@ def preprocess_data(self, examples, skip_output_processing=False):
         output = [self.IN_LABEL_DICT[ex] for ex in examples["label"]]
         return {"input_text": examples["text"], "target_text": output}
 
-    def postprocess_data(self, examples):
-        return [self.OUT_LABEL_DICT.get(ex.strip(), -1) for ex in examples]
-
     def load_dataset(self, split=None):
         dataset = LocalDataset.load_dataset(self, split)
         #dataset = datasets.load_dataset(self.dataset_loc, data_files=self.dataset_info, split=split)
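Both postprocessing changes make label conversion tolerant of values that are already numeric. A standalone illustration that mirrors the logic above; the OUT_LABEL_DICT here is a made-up example, not the dataset's real label map.

def convert_sts_label(label):
    try:
        return float(label.strip())        # string scores such as " 3.5"
    except Exception:
        try:
            return float(label)            # scores that are already numeric (no .strip available)
        except Exception:
            return 0                       # anything unparsable

OUT_LABEL_DICT = {"positive": 1, "negative": 0}    # illustrative label map

def convert_class_label(label):
    if isinstance(label, str):
        return OUT_LABEL_DICT.get(label.strip(), -1)   # generated text is looked up, unknown text maps to -1
    return label                                       # integer labels pass through unchanged

print(convert_sts_label(" 3.5"), convert_sts_label(2.0), convert_sts_label("n/a"))              # 3.5 2.0 0
print(convert_class_label(" positive "), convert_class_label(1), convert_class_label("other"))  # 1 1 -1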
