Initial support for Metal Performance Shaders (MPS) (#98)
dameikle authored Oct 2, 2024
1 parent 4a3d0dd commit a8cfb0d
Showing 12 changed files with 98 additions and 26 deletions.
2 changes: 1 addition & 1 deletion eole/bin/model/extract_embeddings.py
@@ -29,7 +29,7 @@ def add_args(cls, parser):
     @classmethod
     def run(cls, args):
         args.cuda = args.gpu > -1
-        if args.cuda:
+        if args.cuda and torch.cuda.is_available():
             torch.cuda.set_device(args.gpu)
 
         # Add in default model arguments, possibly added since training.
3 changes: 2 additions & 1 deletion eole/inference_engine.py
@@ -4,6 +4,7 @@
 from eole.utils.distributed import ErrorHandler
 from eole.utils.distributed_workers import spawned_infer
 from eole.utils.logging import init_logger
+from eole.utils.misc import get_device_type
 from eole.transforms import get_transforms_cls, make_transforms, TransformPipe

@@ -250,7 +251,7 @@ def __init__(self, config, model_type=None):
         if config.world_size == 1:
             self.device_id = config.gpu_ranks[0]
             self.device_index = config.gpu_ranks
-            self.device = "cuda"
+            self.device = get_device_type()
         else:
             self.device_id = -1
             self.device_index = 0
4 changes: 2 additions & 2 deletions eole/inputters/dynamic_iterator.py
@@ -11,7 +11,7 @@
 )
 from eole.config.data import Dataset
 from eole.utils.logging import init_logger, logger
-from eole.utils.misc import RandomShuffler
+from eole.utils.misc import RandomShuffler, get_device
 from torch.utils.data import DataLoader

@@ -446,7 +446,7 @@ def build_dynamic_dataset_iter(
     if corpora is None:
         assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
         return None
-    device = torch.device(device_id) if device_id >= 0 else torch.device("cpu")
+    device = get_device(device_id=device_id) if device_id >= 0 else torch.device("cpu")
     if hasattr(config, "training"):
         num_workers = getattr(config.training, "num_workers", 0)
     else:
8 changes: 4 additions & 4 deletions eole/models/model.py
@@ -14,7 +14,7 @@
 )
 from torch.nn.utils import skip_init
 from torch.nn.init import xavier_uniform_, zeros_, uniform_
-from eole.utils.misc import use_gpu, sequence_mask
+from eole.utils.misc import use_gpu, sequence_mask, get_device
 from eole.inputters.inputter import dict_to_vocabs

 # copied from model_builder to facilitate tests, but should not live there in the end
@@ -259,11 +259,11 @@ def training_logic(self, running_config, vocabs, checkpoint, device_id):
             running_config.world_size > 1
             and running_config.parallel_mode == "tensor_parallel"
         ):
-            device = torch.device("cuda")
+            device = get_device()
             offset = device_id
         else:
             if use_gpu(running_config):
-                device = torch.device("cuda")
+                device = get_device()
             else:
                 device = torch.device("cpu")
             offset = 0
@@ -340,7 +340,7 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
         if use_gpu(running_config):
             if len(running_config.gpu_ranks) > 0:
                 device_id = running_config.gpu_ranks[0]
-            device = torch.device("cuda", device_id)
+            device = get_device(device_id=device_id)
         else:
             device = torch.device("cpu")
         offset = 0
5 changes: 2 additions & 3 deletions eole/modules/multi_headed_attn.py
@@ -434,7 +434,7 @@ def _compute_attention(
             self.position_encoding_type
             not in [PositionEncodingType.Relative, PositionEncodingType.Alibi]
             and not return_attn
-            and query.device != torch.device("cpu")
+            and query.device.type != "cpu"
         ):
             causal = self.is_decoder and attn_type == "self" and mask is not None
             # keeping this (vs sdpa below) only because it handles windows_size
@@ -462,10 +462,9 @@ def _compute_attention(
                 value,
                 ~mask if mask is not None else None,
                 self.dropout_p,
-                is_causal=causal,
+                is_causal=False,
             )
             attn = None
-
         else:
             query /= sqrt(self.dim_per_head)
             # batch x num_heads x query_len x key_len
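The device check above is subtler than it looks: `torch.device` equality compares both the backend type and the index, so an indexed device does not compare equal to the bare, index-less device of the same type, while `.type` looks only at the backend string. A minimal sketch (not part of the commit) of why comparing `.type` is the robust way to detect non-CPU tensors:

```python
import torch

# Equality compares type AND index, so an indexed device is not equal
# to the index-less device of the same type.
print(torch.device("cuda", 0) == torch.device("cuda"))  # False
print(torch.device("mps", 0) == torch.device("mps"))    # False

# .type ignores the index, so it matches any device of that backend.
print(torch.device("cuda", 0).type != "cpu")  # True
print(torch.device("mps", 0).type != "cpu")   # True
```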
17 changes: 15 additions & 2 deletions eole/predict/beam_search.py
@@ -138,6 +138,14 @@ def initialize_(self, enc_out, device, target_prefix):
         self.topk_scores = torch.empty(
             (self.batch_size, self.beam_size), dtype=torch.float, device=device
         )
+        # MPS doesn't support torch.isin() in Torch 2.3.
+        # Avoid the CPU fallback by selecting an alternative implementation.
+        # Can be removed once Torch 2.4 is supported.
+        self._is_finished_list = (
+            self._is_finished_list_mps
+            if (device is not None and device.type == "mps")
+            else self._is_finished_list_isin
+        )
         """
         self.topk_ids = torch.empty(
             (self.batch_size, self.beam_size), dtype=torch.long, device=device
@@ -382,10 +390,15 @@ def advance(self, log_probs, attn):
             )
             self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()

-        self.is_finished_list = torch.isin(self.topk_ids, self.eos_t).tolist()
-
+        self.is_finished_list = self._is_finished_list()
         self.ensure_max_length()

+    def _is_finished_list_isin(self):
+        return torch.isin(self.topk_ids, self.eos_t).tolist()
+
+    def _is_finished_list_mps(self):
+        return (self.topk_ids.unsqueeze(1) == self.eos_t).sum(dim=1).bool().tolist()
+

 class BeamSearch(BeamSearchBase):
     """
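The MPS fallback replaces `torch.isin` with a broadcast comparison against the EOS ids: each candidate id is compared to every EOS id, the matches are summed away, and the count is cast to a boolean membership mask. A standalone sketch (shapes are assumptions for illustration: `topk_ids` as `(batch, beam)` and the EOS ids as a column vector so the comparison broadcasts) checking that the two paths agree:

```python
import torch

topk_ids = torch.tensor([[5, 2, 7], [2, 9, 5]])  # assumed (batch, beam) candidate ids
eos_t = torch.tensor([[2], [7]])                 # assumed column vector of EOS ids

via_isin = torch.isin(topk_ids, eos_t).tolist()
# (batch, 1, beam) == (n_eos, 1) broadcasts to (batch, n_eos, beam);
# summing over the EOS dim counts matches, bool() turns counts into membership.
via_broadcast = (topk_ids.unsqueeze(1) == eos_t).sum(dim=1).bool().tolist()

assert via_isin == via_broadcast
print(via_broadcast)  # [[False, True, True], [True, False, False]]
```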
14 changes: 14 additions & 0 deletions eole/predict/greedy_search.py
@@ -182,6 +182,14 @@ def initialize(self, enc_out, src_len, device=None, target_prefix=None):
         self.beams_scores = torch.zeros(
             (self.batch_size * self.beam_size, 1), dtype=torch.float, device=device
         )
+        # MPS doesn't support torch.isin() in Torch 2.3.
+        # Avoid the CPU fallback by selecting an alternative implementation.
+        # Can be removed once Torch 2.4 is supported.
+        self._is_finished_list = (
+            self._is_finished_list_mps
+            if (device is not None and device.type == "mps")
+            else self._is_finished_list_isin
+        )
         return fn_map_state, enc_out

     @property
@@ -292,6 +300,12 @@ def update_finished(self):
         self.original_batch_idx = self.original_batch_idx[self.select_indices]
         self.maybe_update_target_prefix(self.select_indices)

+    def _is_finished_list_isin(self):
+        return torch.isin(self.topk_ids, self.eos_t).tolist()
+
+    def _is_finished_list_mps(self):
+        return (self.topk_ids.unsqueeze(1) == self.eos_t).sum(dim=1).bool().tolist()
+

 class GreedySearchLM(GreedySearch):
     def update_finished(self):
10 changes: 4 additions & 6 deletions eole/predict/inference.py
@@ -10,7 +10,7 @@
 from eole.transforms import TransformPipe
 from eole.constants import DefaultTokens
 from eole.predict.prediction import PredictionBuilder
-from eole.utils.misc import set_random_seed, report_matrix, sequence_mask
+from eole.utils.misc import set_random_seed, report_matrix, sequence_mask, get_device
 from eole.utils.alignment import build_align_pharaoh

@@ -111,10 +111,8 @@ def __init__(
         self._tgt_vocab_len = len(self._tgt_vocab)

         self._gpu = gpu
-        self._use_cuda = gpu > -1
-        self._dev = (
-            torch.device("cuda", self._gpu) if self._use_cuda else torch.device("cpu")
-        )
+        self._use_gpu = gpu > -1
+        self._dev = get_device(self._gpu) if self._use_gpu else torch.device("cpu")

         self.n_best = n_best
         self.max_length = max_length
@@ -165,7 +163,7 @@ def __init__(
             "log_probs": [],
         }

-        set_random_seed(seed, self._use_cuda)
+        set_random_seed(seed, self._use_gpu)
         self.with_score = with_score

         self.return_gold_log_probs = return_gold_log_probs
3 changes: 2 additions & 1 deletion eole/train_single.py
@@ -115,7 +115,8 @@ def _init_train(config):

 def configure_process(config, device_id):
     if device_id >= 0:
-        torch.cuda.set_device(device_id)
+        if torch.cuda.is_available():
+            torch.cuda.set_device(device_id)
     set_random_seed(config.seed, device_id >= 0)
9 changes: 5 additions & 4 deletions eole/trainer.py
@@ -16,6 +16,7 @@
 import eole.utils
 from eole.utils.loss import LossCompute
 from eole.utils.logging import logger
+from eole.utils.misc import clear_gpu_cache, get_autocast
 from eole.utils.scoring_utils import ScoringPreparator
 from eole.scorers import get_scorers_cls, build_scorers

@@ -322,7 +323,7 @@ def train(
         report_stats = eole.utils.Statistics()
         self._start_report_manager(start_time=total_stats.start_time)
         # Let's clean the GPUs before training loop
-        torch.cuda.empty_cache()
+        clear_gpu_cache()

         for i, (batches, normalization) in enumerate(self._accum_batches(train_iter)):

@@ -412,7 +413,7 @@ def validate(self, valid_iter, moving_average=None):
             src_len = batch["srclen"]
             tgt = batch["tgt"]

-            with torch.cuda.amp.autocast(enabled=self.optim.amp):
+            with get_autocast(enabled=self.optim.amp):
                 # F-prop through the model.
                 model_out, attns, estim = valid_model(
                     src, tgt, src_len, with_align=self.with_align
@@ -515,7 +516,7 @@ def _gradient_accumulation(
             if self.accum_count == 1:
                 self.optim.zero_grad(set_to_none=True)
             try:
-                with torch.cuda.amp.autocast(enabled=self.optim.amp):
+                with get_autocast(enabled=self.optim.amp):
                     model_out, attns, estim = self.model(
                         src, tgt, src_len, bptt=bptt, with_align=self.with_align
                     )
@@ -549,7 +550,7 @@ def _gradient_accumulation(
                     "Step %d, cuda OOM - batch removed",
                     self.optim.training_step,
                 )
-                torch.cuda.empty_cache()
+                clear_gpu_cache()
                 if self.n_gpu > 1 and self.parallel_mode == "tensor_parallel":
                     torch.distributed.destroy_process_group()
                     sys.exit()
4 changes: 3 additions & 1 deletion eole/utils/loss.py
@@ -10,6 +10,7 @@
 from eole.modules.sparse_activations import LogSparsemax
 from eole.constants import DefaultTokens
 from eole.models.model import DecoderModel
+from eole.utils.misc import get_device

 try:
     import ctranslate2
@@ -75,8 +76,9 @@ def from_config(cls, config, model, vocabs, train=True):
         training/validation logging.
         The Criterion and LossCompute options are triggered by opt settings.
         """
+
         device = torch.device(
-            "cuda" if eole.utils.misc.use_gpu(config.training) else "cpu"
+            get_device() if eole.utils.misc.use_gpu(config.training) else "cpu"
         )
         pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
         padding_idx = vocabs["tgt"][pad_token]
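Note that the changed expression feeds either a `torch.device` (from `get_device()`) or the string `"cpu"` into `torch.device(...)`. This works because, on recent PyTorch builds, the constructor accepts an existing device object as well as a string. A quick check (not from the commit):

```python
import torch

# torch.device accepts a string or an existing device and
# yields an equivalent device object either way.
print(torch.device("cpu"))                # device(type='cpu')
print(torch.device(torch.device("cpu")))  # device(type='cpu')
```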
45 changes: 44 additions & 1 deletion eole/utils/misc.py
@@ -2,7 +2,7 @@

 import torch
 import random
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 import inspect
 import numpy as np
 import os
@@ -72,6 +72,49 @@ def use_gpu(config):
     return hasattr(config, "gpu_ranks") and len(config.gpu_ranks) > 0


+def clear_gpu_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+
+
+def get_device_type():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    else:
+        return "cpu"
+
+
+def get_device(device_id=None):
+    if torch.cuda.is_available():
+        if device_id is not None:
+            return torch.device(f"cuda:{device_id}")
+        else:
+            return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
+
+def get_autocast(enabled=True, device_type="auto"):
+    if not enabled:
+        return nullcontext()
+
+    if device_type == "auto":
+        device_type = get_device_type()
+
+    if device_type == "cuda":
+        return torch.cuda.amp.autocast()
+    elif device_type == "mps":
+        return torch.amp.autocast(device_type="mps")
+    else:
+        return torch.cpu.amp.autocast()
+
+
 def set_random_seed(seed, is_cuda):
     """Sets the random seed."""
     if seed > 0:
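Taken together, these helpers give call sites a single backend-agnostic entry point: CUDA is preferred, then MPS, then CPU. A small usage sketch (hypothetical, not part of the commit) of how they compose:

```python
import torch
from eole.utils.misc import clear_gpu_cache, get_autocast, get_device, get_device_type

print(get_device_type())  # "cuda", "mps", or "cpu", probed in that order
device = get_device()     # matching torch.device; get_device(0) -> cuda:0 on CUDA boxes

x = torch.randn(8, 16, device=device)
# With device_type="auto" (the default), get_autocast dispatches on
# get_device_type(); enabled=False yields a no-op nullcontext instead.
with get_autocast(enabled=True):
    y = x @ x.T

clear_gpu_cache()  # empties the CUDA and/or MPS allocator cache, if present
```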
