Fix misaligned arguments in super() constructor call #6

Open · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion .gitignore
@@ -159,5 +159,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
.vscode
14 changes: 8 additions & 6 deletions README.md
@@ -1,10 +1,12 @@
**Deprecation notice:** this package's responsibilities are now handled, with proper object-oriented design, by a mixin class provided by [ArchIt](https://github.com/bauwenst/ArchIt). This repo will no longer be updated.

# Enhanced Multitask Trainer for Separately Reporting Each Task's Metrics or Losses in HuggingFace Transformers

The [HuggingFace transformers](https://github.com/huggingface/transformers) library is widely used for model training. For example, to adapt a pretrained BERT model to a specific task domain, we often continue pretraining the model on BERT's two pretraining tasks: 1) Next Sentence Prediction (NSP) using the `[CLS]` token, and 2) Masked Language Modeling (MLM) using masked tokens.

**A key issue is that the default `Trainer` in `transformers` assumes the first element of the output is the final loss to minimize. The loss returned by the `forward` method must be a scalar, so when training a multitask model like BERT, the loss needs to be combined.**
*A key issue is that the default `Trainer` in `transformers` assumes the first element of the output is the final loss to minimize. The loss returned by the `forward` method must be a scalar, so when training a multitask model like BERT, the loss needs to be combined.*

**The `Trainer` class offers command-line arguments to control the training process. However, it only provides a combined loss value for all tasks, which obscures the individual losses of each task. This makes it challenging to monitor training and debug different task settings. Additionally, the `Tensorboard` report only shows the combined loss in its metrics.**
*The `Trainer` class offers command-line arguments to control the training process. However, it only provides a combined loss value for all tasks, which obscures the individual losses of each task. This makes it challenging to monitor training and debug different task settings. Additionally, the `Tensorboard` report only shows the combined loss in its metrics.*

This trainer implementation makes it simple to train multitask models while reviewing each task's loss, as well as other training metrics.
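
As a hypothetical illustration of this constraint, a joint NSP + MLM `forward` ends up collapsing its per-task losses into the single scalar that the default `Trainer` consumes (class and argument names below are made up):

```python
import torch
from torch import nn

class JointPretrainingHead(nn.Module):
    """Hypothetical joint NSP + MLM head; only the combined loss reaches the Trainer."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.nsp_classifier = nn.Linear(hidden_size, 2)        # [CLS] -> is-next / not-next
        self.mlm_head = nn.Linear(hidden_size, vocab_size)     # token -> vocabulary logits

    def forward(self, hidden_states, nsp_labels, mlm_labels):
        nsp_logits = self.nsp_classifier(hidden_states[:, 0])  # [CLS] position
        mlm_logits = self.mlm_head(hidden_states)
        nsp_loss = nn.functional.cross_entropy(nsp_logits, nsp_labels)
        mlm_loss = nn.functional.cross_entropy(
            mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1), ignore_index=-100
        )
        loss = mlm_loss + nsp_loss   # the default Trainer only ever sees this sum
        return loss, nsp_logits, mlm_logits
```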

@@ -17,10 +19,10 @@ By the way, another utility you might need is [parser-binding](https://github.co
Follow these steps to use the `HfMultiTaskTrainer`:

1. Install the trainer:

```sh
pip install hf-mtask-trainer
```
```sh
pip install "hf_mtask_trainer @ git+https://github.com/bauwenst/hf-multitask-trainer"
```

2. Replace the default trainer with `HfMultiTaskTrainer`:
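
   A minimal sketch of what that replacement can look like (the model, dataset, and collator are placeholders, and the import path simply follows this repository's package layout):

```python
from transformers import TrainingArguments
from hf_mtask_trainer.trainer import HfMultiTaskTrainer  # used in place of transformers.Trainer

args = TrainingArguments(output_dir="out", logging_steps=50)
trainer = HfMultiTaskTrainer(
    model=model,                  # a model whose submodules opt in via `supports_report_metrics`
    args=args,
    train_dataset=train_dataset,  # placeholder dataset
    data_collator=data_collator,  # placeholder collator
)
trainer.train()
```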

8 changes: 2 additions & 6 deletions hf_mtask_trainer/mixins.py
@@ -29,10 +29,6 @@
from .types import Number


class MultiTaskModuleMixin():

def report_metrics(
self, state: AdditionalState, **metrics: Union[Number, torch.Tensor,
npt.NDArray]
):
class MultiTaskModuleMixin:
def report_metrics(self, state: AdditionalState, **metrics: Union[Number, torch.Tensor, npt.NDArray]):
state.add_metrics(**metrics)
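
For context on how a module ends up calling this mixin: the trainer patches every submodule that sets `supports_report_metrics = True`, and from then on the module can hand per-task values to `report_metrics` inside its `forward`. A hedged sketch (the metric name is arbitrary):

```python
import torch
from torch import nn

class ReportingHead(nn.Module):
    # Opt-in flag checked by the trainer's patching step (see trainer.py below).
    supports_report_metrics = True

    def __init__(self, hidden_size: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 2)

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        loss = nn.functional.cross_entropy(self.classifier(features), labels)
        # Once HfMultiTaskTrainer has patched this module, report_metrics exists and is
        # already bound to the shared AdditionalState, so only the metrics are passed.
        if hasattr(self, "report_metrics"):
            self.report_metrics(nsp_loss=loss.detach())
        return loss
```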
9 changes: 3 additions & 6 deletions hf_mtask_trainer/state.py
@@ -35,8 +35,7 @@
class AdditionalState:

def __init__(self, args: TrainingArguments) -> None:
self.metrics: Dict[str, List[Union[Number, torch.Tensor,
npt.NDArray]]] = defaultdict(list)
self.metrics: Dict[str, List[Union[Number, torch.Tensor, npt.NDArray]]] = defaultdict(list)
self.args = weakref.ref(args)

def add_metrics(self, **metrics: Union[Number, torch.Tensor, npt.NDArray]):
@@ -46,8 +45,7 @@ def add_metrics(self, **metrics: Union[Number, torch.Tensor, npt.NDArray]):
def get_metrics(
self,
step_scale: float = 1.0,
gather_func: Optional[Callable[
[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor]] = None,
gather_func: Optional[Callable[[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor]] = None,
round_digits: Optional[int] = None
) -> Dict[str, Number]:
metrics: Dict[str, List[Number]] = defaultdict(list)
@@ -84,8 +82,7 @@ def get_metrics(

def pop_metrics(
self,
gather_func: Optional[Callable[
[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor]] = None,
gather_func: Optional[Callable[[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor]] = None,
round_digits: Optional[int] = None
):
ret = self.get_metrics(gather_func=gather_func, round_digits=round_digits)
66 changes: 31 additions & 35 deletions hf_mtask_trainer/trainer.py
@@ -32,7 +32,7 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.trainer_utils import EvalPrediction, speed_metrics
from transformers.training_args import TrainingArguments

from .mixins import MultiTaskModuleMixin
@@ -42,15 +42,12 @@


def _patching_module_base(module: Module, additional_state: AdditionalState):
if isinstance(
module, Module
) and hasattr(module, 'supports_report_metrics') and module.supports_report_metrics and MultiTaskModuleMixin not in module.__class__.__bases__:
module.__class__.__bases__ = module.__class__.__bases__ + (
MultiTaskModuleMixin,
)
module.report_metrics = partial(
module.report_metrics, additional_state
)
if isinstance(module, Module) \
and hasattr(module, 'supports_report_metrics') \
and module.supports_report_metrics \
and MultiTaskModuleMixin not in module.__class__.__bases__:
module.__class__.__bases__ = module.__class__.__bases__ + (MultiTaskModuleMixin,)
module.report_metrics = partial(module.report_metrics, additional_state)
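
The base-class patching above is the non-obvious part: it appends the mixin to the module class's `__bases__` at runtime and pre-binds the shared `AdditionalState` with `functools.partial`. A stripped-down illustration of the same pattern with stand-in names:

```python
from functools import partial

class FakeModule:                        # stand-in for torch.nn.Module
    pass

class TaskHead(FakeModule):
    supports_report_metrics = True       # the opt-in flag the patcher looks for

class ReportMixin:                       # stand-in for MultiTaskModuleMixin
    def report_metrics(self, state: dict, **metrics):
        state.update(metrics)

shared_state = {}                        # stand-in for AdditionalState
head = TaskHead()

# The same two moves as _patching_module_base: extend the class's bases,
# then bind the shared state so callers only pass the metrics themselves.
if ReportMixin not in TaskHead.__bases__:
    TaskHead.__bases__ = TaskHead.__bases__ + (ReportMixin,)
head.report_metrics = partial(head.report_metrics, shared_state)

head.report_metrics(nsp_loss=0.48)
print(shared_state)                      # {'nsp_loss': 0.48}
```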


class HfMultiTaskTrainer(Trainer):
@@ -67,45 +64,44 @@ def __init__(
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Optional[Tuple[Optimizer, LambdaLR]] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[
[torch.Tensor, torch.Tensor], torch.Tensor]] = None
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
):
self.additional_state = AdditionalState(args)
if model is not None:
report_pathcing = partial(
_patching_module_base, additional_state=self.additional_state
)
model.apply(report_pathcing)
report_patching = partial(_patching_module_base, additional_state=self.additional_state)
model.apply(report_patching)
super().__init__(
model, args, data_collator, train_dataset, eval_dataset, tokenizer,
model_init, compute_metrics, callbacks, optimizers,
preprocess_logits_for_metrics
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float]=None) -> None:
# Copied from transformers 4.47.0
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self.args.include_num_input_tokens_seen:
logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
if start_time is not None:
speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen)

if hasattr(self, 'additional_state'):
additional_logs = self.additional_state.pop_metrics(
gather_func=self._nested_gather
)
else:
additional_logs = {}
##### Added
additional_logs = self.additional_state.pop_metrics(gather_func=self._nested_gather) if hasattr(self, 'additional_state') else dict()

epoch = logs.pop('epoch', None)
logs.update(additional_logs)
logs['epoch'] = epoch
#####

output = {
**logs,
**{
"step": self.state.global_step
}
}
# Copied from transformers 4.47.0
output = logs | {"step": self.state.global_step}
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(
self.args, self.state, self.control, logs
)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
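
With a model that reports, say, `mlm_loss` and `nsp_loss`, a training log entry then carries the per-task values next to the combined loss. Roughly (values and exact keys depend on the run and on what the model reports):

```python
{'loss': 2.41, 'mlm_loss': 1.93, 'nsp_loss': 0.48, 'epoch': 0.12, 'step': 50}
```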
4 changes: 2 additions & 2 deletions hf_mtask_trainer/types.py
@@ -1,3 +1,3 @@
from typing import NewType
from typing import NewType, Union

Number = NewType('Number', (int, float))
Number = NewType('Number', Union[int, float])
2 changes: 1 addition & 1 deletion setup.cfg
@@ -39,7 +39,7 @@ packages = find:
# six >= 1.10
install_requires =
torch
transformers
transformers >= 4.45.1
accelerate

# Test dependencies, all dependencies for tests here. The format is align to install_requires.
Expand Down