Skip to content

Commit

Permalink
Merge pull request #76 from boun-tabi-LMG/minor-patch
Browse files Browse the repository at this point in the history
Enhancements and Fixes for Classification Functionality
  • Loading branch information
gokceuludogan authored Nov 17, 2024
2 parents 096a71d + 51fd0f7 commit b9d2cbb
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 66 deletions.
2 changes: 1 addition & 1 deletion experiments/conf/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_name: /pretrained_checkpoints/ckpt-1.74M
model_name: boun-tabi-LMG/TURNA
task_mode: '' # '[S2S]: ', '[NLU]: ', '[NLG]: '
training_params:
num_train_epochs: 10
Expand Down
3 changes: 1 addition & 2 deletions experiments/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def main(cfg: DictConfig):
model_trainer = TrainerForConditionalGeneration(model_name, task, training_params, optimizer_params, model_save_path, max_input_length, max_target_length, postprocess_fn)
elif task_format == 'classification':
logger.info("******Classification Mode******")
model_trainer = TrainerForClassification(model_name, task, training_params, optimizer_params, model_save_path, num_labels, postprocess_fn)

model_trainer = TrainerForClassification(model_name, task, training_params, optimizer_params, model_save_path, max_input_length, num_labels, postprocess_fn)
trainer, model = model_trainer.train_and_evaluate(train_dataset, eval_dataset, test_dataset)

logger.info("Best model saved at %s", model_save_path)
Expand Down
9 changes: 6 additions & 3 deletions turkish_lm_tuner/dataset_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

from .tr_datasets import initialize_dataset
from .tr_datasets import initialize_dataset, BaseDataset

class DatasetProcessor:
"""
Expand All @@ -33,7 +33,8 @@ def __init__(self,
tokenizer_name: str = None,
max_input_length: int = None,
max_target_length: int = None,
dataset_loc: str = ''):
dataset_loc: str = '',
dataset: BaseDataset = None):

logger.info(f"Initializing dataset processor for {dataset_name} dataset with {tokenizer_name} tokenizer and {task} task in {task_format} format with {task_mode} mode")
logger.info(f"Max input length: {max_input_length} Max target length: {max_target_length}")
Expand All @@ -46,6 +47,7 @@ def __init__(self,
self.max_input_length = max_input_length
self.max_target_length = max_target_length
self.dataset_loc = dataset_loc
self.dataset = dataset

def load_and_preprocess_data(self, split='train'):
"""
Expand All @@ -54,7 +56,8 @@ def load_and_preprocess_data(self, split='train'):
split: Split of the dataset to be loaded. Either 'train', 'validation' or 'test'
"""
logger.info(f"Loading {split} split of {self.dataset_name} dataset")
self.dataset = initialize_dataset(self.dataset_name, self.dataset_loc)
if self.dataset is None:
self.dataset = initialize_dataset(self.dataset_name, self.dataset_loc)
data = self.dataset.load_dataset(split)

logger.info(f"Preprocessing {self.dataset_name} dataset")
Expand Down
21 changes: 9 additions & 12 deletions turkish_lm_tuner/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .metrics import load_task_metrics
import pandas as pd
import numpy as np
import torch
import os
import logging

Expand Down Expand Up @@ -74,17 +75,10 @@ def __init__(self, model_path, tokenizer_path, task, max_input_length, test_para
self.num_labels = num_labels

def initialize_model(self):
AutoConfig.register("t5_turna_enc", T5ForClassificationConfig)
AutoModel.register(T5ForClassificationConfig, T5ForClassification)
config = AutoConfig.from_pretrained(self.model_path)

if config.model_type in ["t5", "mt5", "t5_turna_enc"]:
if self.task == "classification":
return T5ForClassification.from_pretrained(self.model_path, config, self.num_labels, "single_label_classification")
elif self.task in ["ner", "pos_tagging"]:
return T5ForClassification.from_pretrained(self.model_path, config, self.num_labels, "token_classification")
else:
return T5ForClassification.from_pretrained(self.model_path, config, 1, "regression")
if config.model_type in ["t5", "mt5"]:
return T5ForClassification.from_pretrained(self.model_path)
else:
if self.task == "classification":
return AutoModelForSequenceClassification.from_pretrained(self.model_path, num_labels=self.num_labels)
Expand All @@ -107,6 +101,10 @@ def compute_metrics(self, eval_preds):
preds, labels = eval_preds
if self.task == "semantic_similarity":
preds = preds.flatten()
elif self.task == "multi_label_classification":
sigmoid_outputs = torch.sigmoid(torch.Tensor(preds))
# Apply 0.5 threshold to get binary predictions
preds = (sigmoid_outputs > 0.5).int()
else:
preds = np.argmax(preds, axis=-1)

Expand All @@ -119,12 +117,11 @@ def compute_metrics(self, eval_preds):
labels = self.postprocess_fn(labels)

logger.info("Computing metrics")

result = super().compute_metrics(preds, labels)

logger.info("Predictions: %s", preds[:5])
logger.info("Labels: %s", labels[:5])

result = super().compute_metrics(preds, labels)

predictions = pd.DataFrame(
{'Prediction': preds,
'Label': labels
Expand Down
56 changes: 40 additions & 16 deletions turkish_lm_tuner/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,48 @@ def compute(self, preds, labels):
return self.metric.compute(predictions=preds, references=labels, average="weighted")

class F1(BaseMetric):
def __init__(self):
def __init__(self, average="binary"):
super().__init__("f1")
self.average = average

def compute(self, preds, labels):
return self.metric.compute(predictions=preds, references=labels, average=self.average)

class F1Macro(BaseMetric):
class F1Macro(F1):
def __init__(self):
super().__init__("f1")
super().__init__("macro")

def compute(self, preds, labels):
return self.metric.compute(predictions=preds, references=labels, average="macro")

class F1Micro(BaseMetric):
class F1Micro(F1):
def __init__(self):
super().__init__("f1")
super().__init__("micro")

def compute(self, preds, labels):
return self.metric.compute(predictions=preds, references=labels, average="micro")

class F1Weighted(BaseMetric):
class F1Weighted(F1):
def __init__(self):
super().__init__("f1")
super().__init__("weighted")

def compute(self, preds, labels):
return self.metric.compute(predictions=preds, references=labels, average="weighted")
class F1MultiBase(F1):
def __init__(self, average):
"""
Initializes the F1 multi-base class for multilabel metrics.
Args:
average (str): The averaging method to use (e.g., 'macro', 'micro', 'weighted').
"""
super().__init__(average)
self.metric = evaluate.load("f1", "multilabel")
self.average = average

class F1MultiMacro(F1MultiBase):
def __init__(self):
super().__init__("macro")

class F1MultiMicro(F1MultiBase):
def __init__(self):
super().__init__("micro")

class F1MultiWeighted(F1MultiBase):
def __init__(self):
super().__init__("weighted")

class Pearsonr(BaseMetric):
def __init__(self):
Expand Down Expand Up @@ -154,6 +173,9 @@ def compute(self, preds, labels, **kwargs):
("f1_macro", "F1Macro"),
("f1_micro", "F1Micro"),
("f1_weighted", "F1Weighted"),
("f1_multi_macro", "F1MultiMacro"),
("f1_multi_micro", "F1MultiMicro"),
("f1_multi_weighted", "F1MultiWeighted"),
("pearsonr", "Pearsonr"),
("bleu", "BLEU"),
("meteor", "METEOR"),
Expand Down Expand Up @@ -205,6 +227,8 @@ def load_task_metrics(task):
"""
if task == "classification":
return load_metrics(["accuracy", "precision_weighted", "recall_weighted", "f1_weighted"])
elif task == "mult_label_classification":
return load_metrics(["f1_multi_weighted"])
elif task in ["summarization", "paraphrasing", "title_generation"]:
return load_metrics(["rouge", "bleu", "meteor", "ter"])
elif task == "nli":
Expand Down Expand Up @@ -235,7 +259,7 @@ class Evaluator:
def __init__(self, task=None, metrics=None):
"""
Initializes the Evaluator class.
˜
Args:
task (str, optional): The name of the task for which to load metrics. Defaults to None.
metrics (list, optional): A list of metric names to load. Defaults to None.
Expand Down
27 changes: 8 additions & 19 deletions turkish_lm_tuner/t5_classifier.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from torch import nn
from transformers import T5EncoderModel, PretrainedConfig
from transformers import T5EncoderModel, T5Config
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.t5.modeling_t5 import T5PreTrainedModel

class T5ForClassificationConfig(PretrainedConfig):
model_type="t5_turna_enc"

class T5ForClassification(T5PreTrainedModel): # nn.Module
config_class = T5ForClassificationConfig
"""
T5 encoder adapted for classification
Args:
Expand All @@ -18,29 +14,22 @@ class T5ForClassification(T5PreTrainedModel): # nn.Module
problem_type: Problem type. It can be either 'single_label_classification', 'multi_label_classification', 'token_classification' or 'regression'
dropout_prob: Dropout probability
"""
def __init__(self, pretrained_model_name, config, num_labels, problem_type, dropout_prob=0.1):
def __init__(self, config: T5Config):
super().__init__(config)

try:
self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
except Exception as e:
pretrained_model_name = config._name_or_path
self.encoder = T5EncoderModel.from_pretrained(pretrained_model_name)
self.transformer = T5EncoderModel(config)

self.dropout = nn.Dropout(dropout_prob)
self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)
self.config = self.encoder.config
self.config.num_labels = num_labels
self.config.problem_type = problem_type
self.config.dropout_prob = dropout_prob
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

self.model_parallel = False

def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
encoder_output = self.encoder(input_ids, attention_mask=attention_mask)
def forward(self, input_ids, attention_mask=None, labels=None):
encoder_output = self.transformer(input_ids, attention_mask=attention_mask)

if self.config.problem_type == "token_classification":
sequence_output = encoder_output.last_hidden_state
else:
Expand Down
32 changes: 19 additions & 13 deletions turkish_lm_tuner/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,26 +136,32 @@ def train_and_evaluate(self, train_dataset, eval_dataset, test_dataset):


class TrainerForClassification(BaseModelTrainer):
def __init__(self, model_name, task, training_params, optimizer_params, model_save_path, num_labels, postprocess_fn=None):
def __init__(self, model_name, task, training_params, optimizer_params, model_save_path, max_input_length, num_labels, postprocess_fn=None):
super().__init__(model_name, training_params, optimizer_params)
self.num_labels = num_labels
self.task = task
self.evaluator = EvaluatorForClassification(model_save_path, model_name, task, training_params, postprocess_fn=postprocess_fn)
self.evaluator = EvaluatorForClassification(model_save_path, model_name, task, max_input_length, training_params, num_labels, postprocess_fn=postprocess_fn)

def initialize_model(self):
config = AutoConfig.from_pretrained(self.model_name)

if config.model_type in ["t5", "mt5"]:
if self.task == "classification":
return T5ForClassification(self.model_name, config, self.num_labels, "single_label_classification")
elif self.task in ["ner", "pos_tagging"]:
return T5ForClassification(self.model_name, config, self.num_labels, "token_classification")
else:
return T5ForClassification(self.model_name, config, 1, "regression")
else:
if self.task == "classification":
return AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.num_labels)
elif self.task in ["ner", "pos_tagging"]:
return AutoModelForTokenClassification.from_pretrained(self.model_name, num_labels=self.num_labels)
task_map = {
"classification": "single_label_classification",
"multi_label_classification": "multi_label_classification",
"ner": "token_classification",
"pos_tagging": "token_classification"
}
task_type = task_map.get(self.task, "regression")
num_labels = self.num_labels if task_type != "regression" else 1
return T5ForClassification.from_pretrained(self.model_name, num_labels=num_labels, problem_type=task_type)

if self.task == "classification":
return AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.num_labels)
if self.task == "multi_label_classification":
return AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.num_labels, problem_type=self.task)
if self.task in ["ner", "pos_tagging"]:
return AutoModelForTokenClassification.from_pretrained(self.model_name, num_labels=self.num_labels)

def train_and_evaluate(self, train_dataset, eval_dataset, test_dataset):
logger.info("Training in classification mode")
Expand Down

0 comments on commit b9d2cbb

Please sign in to comment.