Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MacOS mps supported #2039

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def detect(save_img=False):
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu or mps')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
Expand Down
2 changes: 1 addition & 1 deletion models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def attempt_load(weights, map_location=None):
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
ckpt = torch.load(w, map_location=map_location) # load
ckpt = torch.load(w, map_location=map_location) if 'mps' not in map_location.type else torch.load(w, map_location='cpu') # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model

# Compatibility updates
Expand Down
3 changes: 3 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""

if 'mps' in prediction.device.type:
prediction = prediction.cpu()

nc = prediction.shape[2] - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates

Expand Down
20 changes: 16 additions & 4 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from packaging.version import Version

import torch
import torch.backends.cudnn as cudnn
if Version(torch.__version__) >= Version('1.13.0'):
import torch.backends.mps
import torch.nn as nn
import torch.nn.functional as F
import torchvision
Expand Down Expand Up @@ -61,16 +64,20 @@ def git_describe(path=Path(__file__).parent): # path must be a directory


def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
# device = 'cpu' or 'mps' or '0' or '0,1,2,3'
s = f'YOLOR 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu'
if cpu:
mps = device.lower() == 'mps'
if not Version(torch.__version__) >= Version('1.13.0'):
cpu = True
mps = False
if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability

cuda = not cpu and torch.cuda.is_available()
cuda = not cpu and not mps and torch.cuda.is_available()
if cuda:
n = torch.cuda.device_count()
if n > 1 and batch_size: # check that batch_size is compatible with device_count
Expand All @@ -79,11 +86,16 @@ def select_device(device='', batch_size=None):
for i, d in enumerate(device.split(',') if device else range(n)):
p = torch.cuda.get_device_properties(i)
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
device_arg = 'cuda:0'
elif mps and torch.backends.mps.is_available() and torch.backends.mps.is_built():
s += 'MPS\n'
device_arg = 'mps'
else:
s += 'CPU\n'
device_arg = 'cpu'

logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
return torch.device('cuda:0' if cuda else 'cpu')
return torch.device(device_arg)


def time_synchronized():
Expand Down