diff --git a/detect.py b/detect.py index 5e0c4416a4..397bd69fd8 100644 --- a/detect.py +++ b/detect.py @@ -170,7 +170,7 @@ def detect(save_img=False): parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') - parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu or mps') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') diff --git a/models/experimental.py b/models/experimental.py index 735d7aa0eb..5cc07077eb 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -249,7 +249,7 @@ def attempt_load(weights, map_location=None): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: attempt_download(w) - ckpt = torch.load(w, map_location=map_location) # load + ckpt = torch.load(w, map_location=map_location) if 'mps' not in map_location.type else torch.load(w, map_location='cpu') # load model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model # Compatibility updates diff --git a/utils/general.py b/utils/general.py index decdcc64ec..425c7a1417 100644 --- a/utils/general.py +++ b/utils/general.py @@ -613,6 +613,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ + if 'mps' in prediction.device.type: + prediction = prediction.cpu() + nc = prediction.shape[2] - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 1e631b5555..e2a2fa9d6f 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -10,9 +10,12 @@ from contextlib import contextmanager from copy import deepcopy from pathlib import Path +from packaging.version import Version import torch import torch.backends.cudnn as cudnn +if Version(torch.__version__) >= Version('1.13.0'): + import torch.backends.mps import torch.nn as nn import torch.nn.functional as F import torchvision @@ -61,16 +64,20 @@ def git_describe(path=Path(__file__).parent): # path must be a directory def select_device(device='', batch_size=None): - # device = 'cpu' or '0' or '0,1,2,3' + # device = 'cpu' or 'mps' or '0' or '0,1,2,3' s = f'YOLOR 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string cpu = device.lower() == 'cpu' - if cpu: + mps = device.lower() == 'mps' + if not Version(torch.__version__) >= Version('1.13.0'): + cpu = True + mps = False + if cpu or mps: os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False elif device: # non-cpu device requested os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability - cuda = not cpu and torch.cuda.is_available() + cuda = not cpu and not mps and torch.cuda.is_available() if cuda: n = torch.cuda.device_count() if n > 1 and batch_size: # check that batch_size is compatible with device_count @@ -79,11 +86,16 @@ def select_device(device='', batch_size=None): for i, d in enumerate(device.split(',') if device else range(n)): p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB + device_arg = 'cuda:0' + elif mps and torch.backends.mps.is_available() and torch.backends.mps.is_built(): + s += 'MPS\n' + device_arg = 'mps' else: s += 'CPU\n' + device_arg = 'cpu' logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe - return torch.device('cuda:0' if cuda else 'cpu') + return torch.device(device_arg) def time_synchronized():