diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py
new file mode 100644
index 000000000..39ec8c4ce
--- /dev/null
+++ b/flair/distributed_utils.py
@@ -0,0 +1,63 @@
+import logging
+import os
+from multiprocessing.connection import Connection
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from torch.distributed import destroy_process_group, init_process_group
+
+import flair
+from flair.class_utils import T
+
+log = logging.getLogger("flair")
+
+
+def launch_distributed(fn, *args, **kwargs):
+    """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).
+
+    Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process
+    """
+    world_size = torch.cuda.device_count()
+    log.info(f"Launching {world_size} processes")
+    parent_conn, child_conn = mp.Pipe()
+    mp.spawn(_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size)
+    return_value = parent_conn.recv()
+    return return_value
+
+
+def _entrypoint(rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict) -> None:
+    """Lifecycle of a process -- setup, run, cleanup."""
+    log.info(f"Started process on rank={rank}")
+    _ddp_setup(rank, world_size)
+    return_value = fn(*args, **kwargs)
+    if is_main_process():
+        child_conn.send(return_value)
+    destroy_process_group()
+
+
+def _ddp_setup(rank: int, world_size: int) -> None:
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    flair.device = torch.device(rank)
+    torch.cuda.set_device(flair.device)
+    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+
+def is_main_process() -> bool:
+    """True for exactly one process, regardless of whether running on CPU, a single GPU, or multiple GPUs."""
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank() == 0
+    else:
+        return True
+
+
+def aggregate_if_distributed(value: T, aggregation_fn: Callable = np.mean) -> T:
+    """Gathers value from each process and returns the aggregated value according to the supplied function."""
+    if torch.distributed.is_initialized():
+        gathered_values = [None for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather_object(gathered_values, value)
+        return aggregation_fn(gathered_values)
+    else:
+        return value
diff --git a/flair/nn/model.py b/flair/nn/model.py
index eeb5b7c84..0d2b2ed7a 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -47,6 +47,10 @@ def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
         """
         raise NotImplementedError
 
+    def forward(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
+        """Wraps forward_loss to maintain compatibility with hooks."""
+        return self.forward_loss(data_points)
+
     @abstractmethod
     def evaluate(
         self,
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 03e6edc08..f6580e313 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -7,16 +7,21 @@
 import warnings
 from inspect import signature
 from pathlib import Path
-from typing import List, Optional, Tuple, Type, Union
+from queue import Queue
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
+import numpy as np
 import torch
+from torch.nn.parallel import DistributedDataParallel
 from torch.optim.sgd import SGD
+from torch.utils.data import DistributedSampler
 from torch.utils.data.dataset import ConcatDataset
 
 import flair
 import flair.nn
 from flair.data import Corpus, Dictionary, _len_dataset
 from flair.datasets import DataLoader
+from flair.distributed_utils import aggregate_if_distributed, is_main_process, launch_distributed
 from flair.samplers import FlairSampler
 from flair.trainers.plugins import (
     AnnealingPlugin,
@@ -163,6 +168,8 @@ def train(
         create_file_logs: bool = True,
         create_loss_file: bool = True,
         write_weights: bool = False,
+        # scaling
+        multi_gpu: bool = False,
         # plugins
         plugins: Optional[List[TrainerPlugin]] = None,
         attach_default_scheduler: bool = True,
@@ -198,7 +205,12 @@
             "kwargs",
         ]:
             local_variables.pop(var)
-        return self.train_custom(**local_variables, **kwargs)
+
+        if multi_gpu:
+            self._event_queue = None  # Each process will make its own queue rather than sharing one
+            return launch_distributed(self.train_custom, **local_variables, **kwargs)
+        else:
+            return self.train_custom(**local_variables, **kwargs)
 
     def fine_tune(
         self,
@@ -237,8 +249,9 @@
         create_file_logs: bool = True,
         create_loss_file: bool = True,
         write_weights: bool = False,
-        # amp
+        # scaling
         use_amp: bool = False,
+        multi_gpu: bool = False,
         # plugins
         plugins: Optional[List[TrainerPlugin]] = None,
         attach_default_scheduler: bool = True,
@@ -252,47 +265,21 @@
         if attach_default_scheduler:
             plugins.append(LinearSchedulerPlugin(warmup_fraction=warmup_fraction))
 
-        return self.train_custom(
-            base_path=base_path,
-            # training parameters
-            learning_rate=learning_rate,
-            decoder_learning_rate=decoder_learning_rate,
-            mini_batch_size=mini_batch_size,
-            eval_batch_size=eval_batch_size,
-            mini_batch_chunk_size=mini_batch_chunk_size,
-            max_epochs=max_epochs,
-            optimizer=optimizer,
-            train_with_dev=train_with_dev,
-            train_with_test=train_with_test,
-            reduce_transformer_vocab=reduce_transformer_vocab,
-            # evaluation and monitoring
-            main_evaluation_metric=main_evaluation_metric,
-            monitor_test=monitor_test,
-            monitor_train_sample=monitor_train_sample,
-            use_final_model_for_eval=use_final_model_for_eval,
-            gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
-            exclude_labels=exclude_labels,
-            # sampling and shuffling
-            sampler=sampler,
-            shuffle=shuffle,
-            shuffle_first_epoch=shuffle_first_epoch,
-            # evaluation and monitoring
-            embeddings_storage_mode=embeddings_storage_mode,
-            epoch=epoch,
-            # when and what to save
-            save_final_model=save_final_model,
-            save_optimizer_state=save_optimizer_state,
-            save_model_each_k_epochs=save_model_each_k_epochs,
-            # logging parameters
-            create_file_logs=create_file_logs,
-            create_loss_file=create_loss_file,
-            write_weights=write_weights,
-            # amp
-            use_amp=use_amp,
-            # plugins
-            plugins=plugins,
-            **kwargs,
-        )
+        # call self.train_custom with all parameters (minus the ones specific to the LinearSchedulerPlugin)
+        local_variables = locals()
+        for var in [
+            "self",
+            "warmup_fraction",
+            "attach_default_scheduler",
+            "kwargs",
+        ]:
+            local_variables.pop(var)
+
+        if multi_gpu:
+            self._event_queue = None
+            return launch_distributed(self.train_custom, **local_variables, **kwargs)
+        else:
+            return self.train_custom(**local_variables, **kwargs)
 
     def train_custom(
         self,
@@ -331,8 +318,9 @@
         create_file_logs: bool = True,
         create_loss_file: bool = True,
         write_weights: bool = False,
-        # amp
+        # scaling
        use_amp: bool = False,
+        multi_gpu: bool = False,
         # plugins
         plugins: Optional[List[TrainerPlugin]] = None,
         **kwargs,
@@ -375,6 +363,7 @@ def train_custom(
             create_file_logs: If True, logging output is written to a file
             create_loss_file: If True, a loss file logging output is created
             use_amp: If True, uses the torch automatic mixed precision
+            multi_gpu: If True, uses all available GPUs
             write_weights: If True, write weights to weights.txt on each batch logging event.
             plugins: Any additional plugins you want to pass to the trainer
             **kwargs: Additional arguments, for instance for the optimizer
@@ -420,6 +409,11 @@
                 base_path=base_path,
             ).attach_to(self)
 
+        if multi_gpu:
+            self.model.to(flair.device)
+            self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index])
+            self._event_queue = Queue()  # Each process uses its own queue rather than sharing one
+            log.disabled = not is_main_process()  # Disable logging in distributed mode for all but the main process
         # === END BLOCK: ACTIVATE PLUGINS === #
 
         # derive parameters the function was called with (or defaults)
@@ -508,6 +502,7 @@
             if use_final_model_for_eval
             else "model from best epoch (best-model.pt)"
         )
+        computation_device_info = f"{torch.cuda.device_count()} GPUs" if multi_gpu else flair.device
 
         log_line(log)
         log.info(f'Model: "{self.model}"')
@@ -534,7 +529,7 @@
         log.info(f' - metric: "{main_evaluation_metric}"')
         log_line(log)
         log.info("Computation:")
-        log.info(f" - compute on device: {flair.device}")
+        log.info(f" - compute on device: {computation_device_info}")
         log.info(f" - embedding storage: {embeddings_storage_mode}")
         log_line(log)
         log.info(f'Model training base path: "{base_path}"')
@@ -560,12 +555,21 @@
                 if not shuffle_first_epoch and epoch == 1:
                     shuffle_data_this_epoch = False
 
-                batch_loader = DataLoader(
-                    train_data,
-                    batch_size=mini_batch_size,
-                    shuffle=shuffle_data_this_epoch,
-                    sampler=sampler,
-                )
+                if multi_gpu:
+                    batch_loader = DataLoader(
+                        train_data,
+                        batch_size=mini_batch_size,
+                        shuffle=False,
+                        sampler=DistributedSampler(train_data, shuffle=shuffle_data_this_epoch),
+                    )
+                    batch_loader.sampler.set_epoch(epoch - 1)
+                else:
+                    batch_loader = DataLoader(
+                        train_data,
+                        batch_size=mini_batch_size,
+                        shuffle=shuffle_data_this_epoch,
+                        sampler=sampler,
+                    )
 
                 self.model.train()
 
@@ -603,7 +607,10 @@
                     for batch_step in batch_steps:
                         # forward pass
                         with torch.autocast(device_type=flair.device.type, enabled=use_amp):
-                            loss, datapoint_count = self.model.forward_loss(batch_step)
+                            if multi_gpu:
+                                loss, datapoint_count = self.ddp_model(batch_step)
+                            else:
+                                loss, datapoint_count = self.model.forward_loss(batch_step)
 
                         batch_train_samples += datapoint_count
                         batch_train_loss += loss.item()
@@ -649,8 +656,11 @@
                            if epoch_train_samples > 0
                            else epoch_train_samples / (batch_no + 1)
                        )
+                        intermittent_loss = aggregate_if_distributed(intermittent_loss)
 
                        current_time = time.time()
+                        samples_per_second = epoch_train_samples / (current_time - epoch_start_time)
+                        samples_per_second = aggregate_if_distributed(samples_per_second, np.sum)
 
                        lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count)
                        log.info(
@@ -658,7 +668,7 @@
                            f" - iter {batch_no + 1}/{len(batch_loader)}"
                            f" - loss {intermittent_loss:.8f}"
                            f" - time (sec): {(current_time - epoch_start_time):.2f}"
-                            f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}"
+                            f" - samples/sec: {samples_per_second:.2f}"
                            f"{lr_info}{momentum_info}"
                        )
 
@@ -667,6 +677,7 @@
                     self.dispatch("after_training_batch", **batch_kw)
 
                 train_loss = epoch_train_loss / epoch_train_samples
+                train_loss = aggregate_if_distributed(train_loss)
                 self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch))
 
                 total_train_samples += epoch_train_samples
@@ -682,7 +693,7 @@ def train_custom(
 
                 # Determine if this is the best model or if we need to anneal
                 current_epoch_has_best_model_so_far = False
-                validation_scores: tuple
+                validation_scores: tuple = ()
 
                 for evaluation_split, evaluation_split_data in evaluation_splits.items():
                     eval_result = self.model.evaluate(
@@ -722,7 +733,7 @@
                 if not determine_best_epoch_using_dev_score:
                     validation_scores = (train_loss,)
 
-                    if epoch_train_loss < best_epoch_score:
+                    if train_loss < best_epoch_score:
                         current_epoch_has_best_model_so_far = True
                        best_epoch_score = train_loss
 
@@ -737,14 +748,14 @@
 
                 if save_best_model and current_epoch_has_best_model_so_far:
                     log.info("saving best model")
-                    self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)
+                    self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state)
 
            # - SWAPlugin -> restores SGD weights from SWA
            self.dispatch("after_training_loop")
 
            # if we do not use dev data for model selection, save final model
            if save_final_model:
-                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
 
        except KeyboardInterrupt:
            log_line(log)
@@ -754,7 +765,7 @@ def train_custom(
 
            if save_final_model:
                log.info("Saving model ...")
-                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
            log.info("Done.")
 
        except TrainingInterrupt as exc:
@@ -765,7 +776,7 @@ def train_custom(
 
            if save_final_model:
                log.info("Saving model ...")
-                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
+                self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state)
            log.info("Done.")
 
        except Exception:
@@ -783,7 +794,7 @@ def train_custom(
 
                if (base_path / "best-model.pt").exists():
                    log.info("Loading model from best epoch ...")
-                    self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
+                    self._load_model(base_path / "best-model.pt")
                else:
                    log.info("Testing using last state of model ...")
 
@@ -808,7 +819,7 @@
            else:
                if (base_path / "best-model.pt").exists():
                    log.info("Loading model from best epoch ...")
-                    self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
+                    self._load_model(base_path / "best-model.pt")
 
                self.return_values["test_score"] = 0
                log.info("Test data not provided setting final score to 0")
@@ -824,7 +835,9 @@ def train_custom(
 
    def _get_current_lr_and_momentum(self, batch_count):
        current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
+        current_learning_rate = [aggregate_if_distributed(m) for m in current_learning_rate]
        momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups]
+        momentum = [aggregate_if_distributed(m) for m in momentum]
        lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate])
        momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum])
        self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count))
@@ -905,3 +918,22 @@ def _initialize_model_card(self, **training_parameters):
 
    def _record(self, metric):
        self.dispatch("metric_recorded", metric)
+
+    def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
+        """Saves the current model. Safe to call from a distributed context.
+
+        Args:
+            model_file: the model file
+            checkpoint: passed through to model.save.
+        """
+        if is_main_process():
+            self.model.save(model_file, checkpoint)
+
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
+
+    def _load_model(self, model_file: Union[str, Path]) -> None:
+        """Loads the model from the given file into the current state. Safe to call from a distributed context."""
+        self.model.load_state_dict(self.model.load(model_file).state_dict())
+        if torch.distributed.is_initialized():
+            self.ddp_model = DistributedDataParallel(self.model, device_ids=[flair.device.index])