From 8f3a936726ca456aa6e4382ee63146ae1e550c41 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 23 Sep 2024 22:54:48 -0600 Subject: [PATCH 1/2] GH-3429: Add experimental multi GPU support --- flair/__init__.py | 6 + flair/distributed_utils.py | 49 +++++++++ flair/nn/model.py | 3 + flair/trainers/trainer.py | 220 ++++++++++++++++++++----------------- 4 files changed, 176 insertions(+), 102 deletions(-) create mode 100644 flair/distributed_utils.py diff --git a/flair/__init__.py b/flair/__init__.py index 341f630e4..55bd6d00c 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -33,6 +33,12 @@ else: device = torch.device("cpu") +distributed = False +"""Experimental flag to indicate multiple GPUs are in use. + +Set by `launch_distributed` -- do not set manually. +""" + # global variable: version __version__ = "0.14.0" """The current version of the flair library installed.""" diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py new file mode 100644 index 000000000..596570b7c --- /dev/null +++ b/flair/distributed_utils.py @@ -0,0 +1,49 @@ +import logging +import os + +import torch +import torch.multiprocessing as mp +from torch.distributed import destroy_process_group, init_process_group + +import flair + +log = logging.getLogger("flair") + + +def launch_distributed(fp, *args): + """Executes the function fp(*args) on multiple GPUs (all local GPUs)""" + world_size = torch.cuda.device_count() + log.info(f"Launching {world_size} distributed processes") + mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size) + + +def entrypoint(rank, world_size, fp, *args): + ddp_setup(rank, world_size) + fp(*args) + destroy_process_group() + + +def ddp_setup(rank: int, world_size: int) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + init_process_group(backend="nccl", rank=rank, world_size=world_size) + flair.distributed = True + flair.device = torch.device(rank) + torch.cuda.set_device(flair.device) + + +def is_main_process() -> bool: + """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu""" + if flair.distributed: + return flair.device.index == 0 + else: + return True + + +class DistributedModel(torch.nn.parallel.DistributedDataParallel): + """DistributedDataParallel, but redirects access to methods and attributes to the original Model""" + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) diff --git a/flair/nn/model.py b/flair/nn/model.py index eeb5b7c84..1c3e999b3 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,6 +17,7 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -118,6 +119,8 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_file: the model file checkpoint: currently unused. """ + if not is_main_process(): + return model_state = self._get_state_dict() # write out a "model card" if one is set diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 03e6edc08..196fe39b9 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -11,12 +11,14 @@ import torch from torch.optim.sgd import SGD +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import ConcatDataset import flair import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader +from flair.distributed_utils import DistributedModel, is_main_process from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -420,6 +422,8 @@ def train_custom( base_path=base_path, ).attach_to(self) + if flair.distributed: + self.model = DistributedModel(self.model, device_ids=[flair.device.index]) # === END BLOCK: ACTIVATE PLUGINS === # # derive parameters the function was called with (or defaults) @@ -509,36 +513,37 @@ def train_custom( else "model from best epoch (best-model.pt)" ) - log_line(log) - log.info(f'Model: "{self.model}"') - log_line(log) - log.info(f"{self.corpus}") - log_line(log) - log.info(f"Train: {len(train_data)} sentences") - log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") - log_line(log) - log.info("Training Params:") - log.info( - f' - learning_rate: "{learning_rate}" ' - f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' - ) - log.info(f' - mini_batch_size: "{mini_batch_size}"') - log.info(f' - max_epochs: "{max_epochs}"') - log.info(f' - shuffle: "{shuffle}"') - log_line(log) - log.info("Plugins:") - for plugin in plugins: - log.info(" - " + str(plugin)) - log_line(log) - log.info(f"Final evaluation on {final_eval_info}") - log.info(f' - metric: "{main_evaluation_metric}"') - log_line(log) - log.info("Computation:") - log.info(f" - compute on device: {flair.device}") - log.info(f" - embedding storage: {embeddings_storage_mode}") - log_line(log) - log.info(f'Model training base path: "{base_path}"') - log_line(log) + if is_main_process(): + log_line(log) + log.info(f'Model: "{self.model}"') + log_line(log) + log.info(f"{self.corpus}") + log_line(log) + log.info(f"Train: {len(train_data)} sentences") + log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") + log_line(log) + log.info("Training Params:") + log.info( + f' - learning_rate: "{learning_rate}" ' + f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' + ) + log.info(f' - mini_batch_size: "{mini_batch_size}"') + log.info(f' - max_epochs: "{max_epochs}"') + log.info(f' - shuffle: "{shuffle}"') + log_line(log) + log.info("Plugins:") + for plugin in plugins: + log.info(" - " + str(plugin)) + log_line(log) + log.info(f"Final evaluation on {final_eval_info}") + log.info(f' - metric: "{main_evaluation_metric}"') + log_line(log) + log.info("Computation:") + log.info(f" - compute on device: {flair.device}") + log.info(f" - embedding storage: {embeddings_storage_mode}") + log_line(log) + log.info(f'Model training base path: "{base_path}"') + log_line(log) # At any point you can hit Ctrl + C to break out of training early. try: @@ -560,12 +565,21 @@ def train_custom( if not shuffle_first_epoch and epoch == 1: shuffle_data_this_epoch = False - batch_loader = DataLoader( - train_data, - batch_size=mini_batch_size, - shuffle=shuffle_data_this_epoch, - sampler=sampler, - ) + if flair.distributed: + batch_loader = DataLoader( + train_data, + batch_size=mini_batch_size, + shuffle=False, + sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch), + ) + batch_loader.sampler.set_epoch(epoch) + else: + batch_loader = DataLoader( + train_data, + batch_size=mini_batch_size, + shuffle=shuffle_data_this_epoch, + sampler=sampler, + ) self.model.train() @@ -682,49 +696,50 @@ def train_custom( # Determine if this is the best model or if we need to anneal current_epoch_has_best_model_so_far = False - validation_scores: tuple - - for evaluation_split, evaluation_split_data in evaluation_splits.items(): - eval_result = self.model.evaluate( - evaluation_split_data, - out_path=base_path / f"{evaluation_split}.tsv", - mini_batch_size=eval_batch_size, - exclude_labels=exclude_labels, - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - embedding_storage_mode=embeddings_storage_mode, - gold_label_type=self.model.label_type, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - ) + validation_scores = () + + if is_main_process(): + for evaluation_split, evaluation_split_data in evaluation_splits.items(): + eval_result = self.model.evaluate( + evaluation_split_data, + out_path=base_path / f"{evaluation_split}.tsv", + mini_batch_size=eval_batch_size, + exclude_labels=exclude_labels, + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + embedding_storage_mode=embeddings_storage_mode, + gold_label_type=self.model.label_type, + gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + ) - # log results - log.info( - f"{evaluation_split.upper()} : loss {eval_result.loss}" - f" - {main_evaluation_metric[1]}" - f" ({main_evaluation_metric[0]})" - f" {round(eval_result.main_score, 4)}" - ) + # log results + log.info( + f"{evaluation_split.upper()} : loss {eval_result.loss}" + f" - {main_evaluation_metric[1]}" + f" ({main_evaluation_metric[0]})" + f" {round(eval_result.main_score, 4)}" + ) - # depending on memory mode, embeddings are moved to CPU, GPU or deleted - store_embeddings(evaluation_split_data, embeddings_storage_mode) + # depending on memory mode, embeddings are moved to CPU, GPU or deleted + store_embeddings(evaluation_split_data, embeddings_storage_mode) - self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) + self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) - # use DEV split to determine if this is the best model so far - if determine_best_epoch_using_dev_score and evaluation_split == "dev": - validation_scores = eval_result.main_score, eval_result.loss + # use DEV split to determine if this is the best model so far + if determine_best_epoch_using_dev_score and evaluation_split == "dev": + validation_scores = eval_result.main_score, eval_result.loss - if eval_result.main_score > best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = eval_result.main_score + if eval_result.main_score > best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = eval_result.main_score - # if not using DEV score, determine best model using train loss - if not determine_best_epoch_using_dev_score: - validation_scores = (train_loss,) + # if not using DEV score, determine best model using train loss + if not determine_best_epoch_using_dev_score: + validation_scores = (train_loss,) - if epoch_train_loss < best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = train_loss + if epoch_train_loss < best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = train_loss # - LossFilePlugin -> somehow prints all relevant metrics # - AnnealPlugin -> scheduler step @@ -776,41 +791,42 @@ def train_custom( self.dispatch("_training_finally") # test best model if test data is present - if self.corpus.test and not train_with_test: - log_line(log) + if is_main_process(): + if self.corpus.test and not train_with_test: + log_line(log) - self.model.eval() + self.model.eval() - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - else: - log.info("Testing using last state of model ...") - - test_results = self.model.evaluate( - self.corpus.test, - gold_label_type=self.model.label_type, - mini_batch_size=eval_batch_size, - out_path=base_path / "test.tsv", - embedding_storage_mode="none", - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - return_loss=False, - ) + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + else: + log.info("Testing using last state of model ...") + + test_results = self.model.evaluate( + self.corpus.test, + gold_label_type=self.model.label_type, + mini_batch_size=eval_batch_size, + out_path=base_path / "test.tsv", + embedding_storage_mode="none", + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + exclude_labels=exclude_labels, + return_loss=False, + ) - log.info(test_results.detailed_results) - log_line(log) + log.info(test_results.detailed_results) + log_line(log) - # get and return the final test score of best model - self.return_values["test_score"] = test_results.main_score + # get and return the final test score of best model + self.return_values["test_score"] = test_results.main_score - else: - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - self.return_values["test_score"] = 0 - log.info("Test data not provided setting final score to 0") + else: + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + self.return_values["test_score"] = 0 + log.info("Test data not provided setting final score to 0") # MetricHistoryPlugin -> stores the loss history in return_values self.dispatch("after_training") From 546e290042788a3879081aef9c9542c44402a910 Mon Sep 17 00:00:00 2001 From: Jeff Picard Date: Mon, 30 Sep 2024 16:45:25 -0600 Subject: [PATCH 2/2] GH-3429: Move process spawning inside `.train`; WIP sync gradients --- flair/__init__.py | 6 - flair/distributed_utils.py | 58 ++++--- flair/nn/model.py | 7 +- flair/trainers/trainer.py | 328 +++++++++++++++++++------------------ 4 files changed, 212 insertions(+), 187 deletions(-) diff --git a/flair/__init__.py b/flair/__init__.py index 55bd6d00c..341f630e4 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -33,12 +33,6 @@ else: device = torch.device("cpu") -distributed = False -"""Experimental flag to indicate multiple GPUs are in use. - -Set by `launch_distributed` -- do not set manually. -""" - # global variable: version __version__ = "0.14.0" """The current version of the flair library installed.""" diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index 596570b7c..39ec8c4ce 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -1,49 +1,63 @@ import logging import os +from multiprocessing.connection import Connection +from typing import Callable +import numpy as np import torch import torch.multiprocessing as mp from torch.distributed import destroy_process_group, init_process_group import flair +from flair.class_utils import T log = logging.getLogger("flair") -def launch_distributed(fp, *args): - """Executes the function fp(*args) on multiple GPUs (all local GPUs)""" - world_size = torch.cuda.device_count() - log.info(f"Launching {world_size} distributed processes") - mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size) - +def launch_distributed(fn, *args, **kwargs): + """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU). -def entrypoint(rank, world_size, fp, *args): - ddp_setup(rank, world_size) - fp(*args) + Returns: the return value of the function fp(*args, **kwargs) from the rank 0 process + """ + world_size = torch.cuda.device_count() + log.info(f"Launching {world_size} processes") + parent_conn, child_conn = mp.Pipe() + mp.spawn(_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size) + return_value = parent_conn.recv() + return return_value + + +def _entrypoint(rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict) -> None: + """Lifecycle of a process -- setup, run, cleanup.""" + log.info(f"Started process on rank={rank}") + _ddp_setup(rank, world_size) + return_value = fn(*args, **kwargs) + if is_main_process(): + child_conn.send(return_value) destroy_process_group() -def ddp_setup(rank: int, world_size: int) -> None: +def _ddp_setup(rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" - init_process_group(backend="nccl", rank=rank, world_size=world_size) - flair.distributed = True flair.device = torch.device(rank) torch.cuda.set_device(flair.device) + init_process_group(backend="nccl", rank=rank, world_size=world_size) def is_main_process() -> bool: - """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu""" - if flair.distributed: - return flair.device.index == 0 + """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu.""" + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() == 0 else: return True -class DistributedModel(torch.nn.parallel.DistributedDataParallel): - """DistributedDataParallel, but redirects access to methods and attributes to the original Model""" - def __getattr__(self, name): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.module, name) +def aggregate_if_distributed(value: T, aggregation_fn: Callable = np.mean) -> T: + """Gathers value from each process and returns the aggregated value according to the supplied function.""" + if torch.distributed.is_initialized(): + gathered_values = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_values, value) + return aggregation_fn(gathered_values) + else: + return value diff --git a/flair/nn/model.py b/flair/nn/model.py index 1c3e999b3..0d2b2ed7a 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,7 +17,6 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset -from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -48,6 +47,10 @@ def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: """ raise NotImplementedError + def forward(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]: + """Wraps forward_loss to maintain compatibility with hooks.""" + return self.forward_loss(data_points) + @abstractmethod def evaluate( self, @@ -119,8 +122,6 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: model_file: the model file checkpoint: currently unused. """ - if not is_main_process(): - return model_state = self._get_state_dict() # write out a "model card" if one is set diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 196fe39b9..f6580e313 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -7,9 +7,12 @@ import warnings from inspect import signature from pathlib import Path -from typing import List, Optional, Tuple, Type, Union +from queue import Queue +from typing import Any, Dict, List, Optional, Tuple, Type, Union +import numpy as np import torch +from torch.nn.parallel import DistributedDataParallel from torch.optim.sgd import SGD from torch.utils.data import DistributedSampler from torch.utils.data.dataset import ConcatDataset @@ -18,7 +21,7 @@ import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader -from flair.distributed_utils import DistributedModel, is_main_process +from flair.distributed_utils import aggregate_if_distributed, is_main_process, launch_distributed from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -165,6 +168,8 @@ def train( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, + # scaling + multi_gpu: bool = False, # plugins plugins: Optional[List[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -200,7 +205,12 @@ def train( "kwargs", ]: local_variables.pop(var) - return self.train_custom(**local_variables, **kwargs) + + if multi_gpu: + self._event_queue = None # Each process will make its own queue rather than share + return launch_distributed(self.train_custom, **local_variables, **kwargs) + else: + return self.train_custom(**local_variables, **kwargs) def fine_tune( self, @@ -239,8 +249,9 @@ def fine_tune( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # scaling use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[List[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -254,47 +265,21 @@ def fine_tune( if attach_default_scheduler: plugins.append(LinearSchedulerPlugin(warmup_fraction=warmup_fraction)) - return self.train_custom( - base_path=base_path, - # training parameters - learning_rate=learning_rate, - decoder_learning_rate=decoder_learning_rate, - mini_batch_size=mini_batch_size, - eval_batch_size=eval_batch_size, - mini_batch_chunk_size=mini_batch_chunk_size, - max_epochs=max_epochs, - optimizer=optimizer, - train_with_dev=train_with_dev, - train_with_test=train_with_test, - reduce_transformer_vocab=reduce_transformer_vocab, - # evaluation and monitoring - main_evaluation_metric=main_evaluation_metric, - monitor_test=monitor_test, - monitor_train_sample=monitor_train_sample, - use_final_model_for_eval=use_final_model_for_eval, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - # sampling and shuffling - sampler=sampler, - shuffle=shuffle, - shuffle_first_epoch=shuffle_first_epoch, - # evaluation and monitoring - embeddings_storage_mode=embeddings_storage_mode, - epoch=epoch, - # when and what to save - save_final_model=save_final_model, - save_optimizer_state=save_optimizer_state, - save_model_each_k_epochs=save_model_each_k_epochs, - # logging parameters - create_file_logs=create_file_logs, - create_loss_file=create_loss_file, - write_weights=write_weights, - # amp - use_amp=use_amp, - # plugins - plugins=plugins, - **kwargs, - ) + # call self.train_custom with all parameters (minus the ones specific to the LinearSchedulerPlugin) + local_variables = locals() + for var in [ + "self", + "warmup_fraction", + "attach_default_scheduler", + "kwargs", + ]: + local_variables.pop(var) + + if multi_gpu: + self._event_queue = None + return launch_distributed(self.train_custom, **local_variables, **kwargs) + else: + return self.train_custom(**local_variables, **kwargs) def train_custom( self, @@ -333,8 +318,9 @@ def train_custom( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # scaling use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[List[TrainerPlugin]] = None, **kwargs, @@ -377,6 +363,7 @@ def train_custom( create_file_logs: If True, logging output is written to a file create_loss_file: If True, a loss file logging output is created use_amp: If True, uses the torch automatic mixed precision + multi_gpu: If True, uses all available GPUs write_weights: If True, write weights to weights.txt on each batch logging event. plugins: Any additional plugins you want to pass to the trainer **kwargs: Additional arguments, for instance for the optimizer @@ -422,8 +409,11 @@ def train_custom( base_path=base_path, ).attach_to(self) - if flair.distributed: - self.model = DistributedModel(self.model, device_ids=[flair.device.index]) + if multi_gpu: + self.model.to(flair.device) + self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index]) + self._event_queue = Queue() # Each process uses its own queue rather than share + log.disabled = not is_main_process() # Disable logging in distributed mode for all but the main process # === END BLOCK: ACTIVATE PLUGINS === # # derive parameters the function was called with (or defaults) @@ -512,38 +502,38 @@ def train_custom( if use_final_model_for_eval else "model from best epoch (best-model.pt)" ) - - if is_main_process(): - log_line(log) - log.info(f'Model: "{self.model}"') - log_line(log) - log.info(f"{self.corpus}") - log_line(log) - log.info(f"Train: {len(train_data)} sentences") - log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") - log_line(log) - log.info("Training Params:") - log.info( - f' - learning_rate: "{learning_rate}" ' - f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' - ) - log.info(f' - mini_batch_size: "{mini_batch_size}"') - log.info(f' - max_epochs: "{max_epochs}"') - log.info(f' - shuffle: "{shuffle}"') - log_line(log) - log.info("Plugins:") - for plugin in plugins: - log.info(" - " + str(plugin)) - log_line(log) - log.info(f"Final evaluation on {final_eval_info}") - log.info(f' - metric: "{main_evaluation_metric}"') - log_line(log) - log.info("Computation:") - log.info(f" - compute on device: {flair.device}") - log.info(f" - embedding storage: {embeddings_storage_mode}") - log_line(log) - log.info(f'Model training base path: "{base_path}"') - log_line(log) + computation_device_info = f"{torch.cuda.device_count()} GPUs" if multi_gpu else flair.device + + log_line(log) + log.info(f'Model: "{self.model}"') + log_line(log) + log.info(f"{self.corpus}") + log_line(log) + log.info(f"Train: {len(train_data)} sentences") + log.info(f" (train_with_dev={train_with_dev}, train_with_test={train_with_test})") + log_line(log) + log.info("Training Params:") + log.info( + f' - learning_rate: "{learning_rate}" ' + f'{"(decoder: " + str(decoder_learning_rate) + ")" if decoder_learning_rate else ""}' + ) + log.info(f' - mini_batch_size: "{mini_batch_size}"') + log.info(f' - max_epochs: "{max_epochs}"') + log.info(f' - shuffle: "{shuffle}"') + log_line(log) + log.info("Plugins:") + for plugin in plugins: + log.info(" - " + str(plugin)) + log_line(log) + log.info(f"Final evaluation on {final_eval_info}") + log.info(f' - metric: "{main_evaluation_metric}"') + log_line(log) + log.info("Computation:") + log.info(f" - compute on device: {computation_device_info}") + log.info(f" - embedding storage: {embeddings_storage_mode}") + log_line(log) + log.info(f'Model training base path: "{base_path}"') + log_line(log) # At any point you can hit Ctrl + C to break out of training early. try: @@ -565,14 +555,14 @@ def train_custom( if not shuffle_first_epoch and epoch == 1: shuffle_data_this_epoch = False - if flair.distributed: + if multi_gpu: batch_loader = DataLoader( train_data, batch_size=mini_batch_size, shuffle=False, sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch), ) - batch_loader.sampler.set_epoch(epoch) + batch_loader.sampler.set_epoch(epoch - 1) else: batch_loader = DataLoader( train_data, @@ -617,7 +607,10 @@ def train_custom( for batch_step in batch_steps: # forward pass with torch.autocast(device_type=flair.device.type, enabled=use_amp): - loss, datapoint_count = self.model.forward_loss(batch_step) + if multi_gpu: + loss, datapoint_count = self.ddp_model(batch_step) + else: + loss, datapoint_count = self.model.forward_loss(batch_step) batch_train_samples += datapoint_count batch_train_loss += loss.item() @@ -663,8 +656,11 @@ def train_custom( if epoch_train_samples > 0 else epoch_train_samples / (batch_no + 1) ) + intermittent_loss = aggregate_if_distributed(intermittent_loss) current_time = time.time() + samples_per_second = epoch_train_samples / (current_time - epoch_start_time) + samples_per_second = aggregate_if_distributed(samples_per_second, np.sum) lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count) log.info( @@ -672,7 +668,7 @@ def train_custom( f" - iter {batch_no + 1}/{len(batch_loader)}" f" - loss {intermittent_loss:.8f}" f" - time (sec): {(current_time - epoch_start_time):.2f}" - f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}" + f" - samples/sec: {samples_per_second:.2f}" f"{lr_info}{momentum_info}" ) @@ -681,6 +677,7 @@ def train_custom( self.dispatch("after_training_batch", **batch_kw) train_loss = epoch_train_loss / epoch_train_samples + train_loss = aggregate_if_distributed(train_loss) self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch)) total_train_samples += epoch_train_samples @@ -696,50 +693,49 @@ def train_custom( # Determine if this is the best model or if we need to anneal current_epoch_has_best_model_so_far = False - validation_scores = () - - if is_main_process(): - for evaluation_split, evaluation_split_data in evaluation_splits.items(): - eval_result = self.model.evaluate( - evaluation_split_data, - out_path=base_path / f"{evaluation_split}.tsv", - mini_batch_size=eval_batch_size, - exclude_labels=exclude_labels, - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - embedding_storage_mode=embeddings_storage_mode, - gold_label_type=self.model.label_type, - gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, - ) + validation_scores: tuple = () + + for evaluation_split, evaluation_split_data in evaluation_splits.items(): + eval_result = self.model.evaluate( + evaluation_split_data, + out_path=base_path / f"{evaluation_split}.tsv", + mini_batch_size=eval_batch_size, + exclude_labels=exclude_labels, + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + embedding_storage_mode=embeddings_storage_mode, + gold_label_type=self.model.label_type, + gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + ) - # log results - log.info( - f"{evaluation_split.upper()} : loss {eval_result.loss}" - f" - {main_evaluation_metric[1]}" - f" ({main_evaluation_metric[0]})" - f" {round(eval_result.main_score, 4)}" - ) + # log results + log.info( + f"{evaluation_split.upper()} : loss {eval_result.loss}" + f" - {main_evaluation_metric[1]}" + f" ({main_evaluation_metric[0]})" + f" {round(eval_result.main_score, 4)}" + ) - # depending on memory mode, embeddings are moved to CPU, GPU or deleted - store_embeddings(evaluation_split_data, embeddings_storage_mode) + # depending on memory mode, embeddings are moved to CPU, GPU or deleted + store_embeddings(evaluation_split_data, embeddings_storage_mode) - self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) + self._publish_eval_result(eval_result, evaluation_split, global_step=epoch) - # use DEV split to determine if this is the best model so far - if determine_best_epoch_using_dev_score and evaluation_split == "dev": - validation_scores = eval_result.main_score, eval_result.loss + # use DEV split to determine if this is the best model so far + if determine_best_epoch_using_dev_score and evaluation_split == "dev": + validation_scores = eval_result.main_score, eval_result.loss - if eval_result.main_score > best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = eval_result.main_score + if eval_result.main_score > best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = eval_result.main_score - # if not using DEV score, determine best model using train loss - if not determine_best_epoch_using_dev_score: - validation_scores = (train_loss,) + # if not using DEV score, determine best model using train loss + if not determine_best_epoch_using_dev_score: + validation_scores = (train_loss,) - if epoch_train_loss < best_epoch_score: - current_epoch_has_best_model_so_far = True - best_epoch_score = train_loss + if train_loss < best_epoch_score: + current_epoch_has_best_model_so_far = True + best_epoch_score = train_loss # - LossFilePlugin -> somehow prints all relevant metrics # - AnnealPlugin -> scheduler step @@ -752,14 +748,14 @@ def train_custom( if save_best_model and current_epoch_has_best_model_so_far: log.info("saving best model") - self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state) # - SWAPlugin -> restores SGD weights from SWA self.dispatch("after_training_loop") # if we do not use dev data for model selection, save final model if save_final_model: - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) except KeyboardInterrupt: log_line(log) @@ -769,7 +765,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except TrainingInterrupt as exc: @@ -780,7 +776,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except Exception: @@ -791,42 +787,41 @@ def train_custom( self.dispatch("_training_finally") # test best model if test data is present - if is_main_process(): - if self.corpus.test and not train_with_test: - log_line(log) + if self.corpus.test and not train_with_test: + log_line(log) - self.model.eval() + self.model.eval() - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - else: - log.info("Testing using last state of model ...") - - test_results = self.model.evaluate( - self.corpus.test, - gold_label_type=self.model.label_type, - mini_batch_size=eval_batch_size, - out_path=base_path / "test.tsv", - embedding_storage_mode="none", - main_evaluation_metric=main_evaluation_metric, - gold_label_dictionary=gold_label_dictionary_for_eval, - exclude_labels=exclude_labels, - return_loss=False, - ) + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self._load_model(base_path / "best-model.pt") + else: + log.info("Testing using last state of model ...") + + test_results = self.model.evaluate( + self.corpus.test, + gold_label_type=self.model.label_type, + mini_batch_size=eval_batch_size, + out_path=base_path / "test.tsv", + embedding_storage_mode="none", + main_evaluation_metric=main_evaluation_metric, + gold_label_dictionary=gold_label_dictionary_for_eval, + exclude_labels=exclude_labels, + return_loss=False, + ) - log.info(test_results.detailed_results) - log_line(log) + log.info(test_results.detailed_results) + log_line(log) - # get and return the final test score of best model - self.return_values["test_score"] = test_results.main_score + # get and return the final test score of best model + self.return_values["test_score"] = test_results.main_score - else: - if (base_path / "best-model.pt").exists(): - log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) - self.return_values["test_score"] = 0 - log.info("Test data not provided setting final score to 0") + else: + if (base_path / "best-model.pt").exists(): + log.info("Loading model from best epoch ...") + self._load_model(base_path / "best-model.pt") + self.return_values["test_score"] = 0 + log.info("Test data not provided setting final score to 0") # MetricHistoryPlugin -> stores the loss history in return_values self.dispatch("after_training") @@ -840,7 +835,9 @@ def train_custom( def _get_current_lr_and_momentum(self, batch_count): current_learning_rate = [group["lr"] for group in self.optimizer.param_groups] + current_learning_rate = [aggregate_if_distributed(m) for m in current_learning_rate] momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups] + momentum = [aggregate_if_distributed(m) for m in momentum] lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate]) momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum]) self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count)) @@ -921,3 +918,22 @@ def _initialize_model_card(self, **training_parameters): def _record(self, metric): self.dispatch("metric_recorded", metric) + + def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: + """Saves the current model. Safe to call from a distributed context. + + Args: + model_file: the model file + checkpoint: currently unused. + """ + if is_main_process(): + self.model.save(model_file, checkpoint) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete + + def _load_model(self, model_file: Union[str, Path]) -> None: + """Loads the model from the given file into the current state. Safe to call from a distributed context.""" + self.model.load_state_dict(self.model.load(model_file).state_dict()) + if torch.distributed.is_initialized(): + self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index])