Hi there, thanks for this package, it's really helpful!
On a cluster with multiple GPUs, I have my model on device cuda:1.
When calculating FID with a passed gen function, new samples are generated during FID calculation. To that end, a model_fn(x) function is defined in clean-fid/cleanfid/features.py (lines 23 to 25 in bd44693), and if use_dataparallel=True, the model is wrapped with model = torch.nn.DataParallel(model).
Problem: DataParallel has a kwarg device_ids=None, which defaults to all available devices and then selects the first one as the "source" device, i.e., cuda:0. It later asserts that all parameters and buffers of the model are on that device. So if device_ids is not passed, this results in an error whenever the model's device is not cuda:0.
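For illustration, here is a minimal sketch of the failure mode (assuming a machine with at least two GPUs; the toy module is hypothetical, not clean-fid's actual feature extractor):

```python
import torch

model = torch.nn.Linear(10, 10).to("cuda:1")  # model lives on cuda:1
model = torch.nn.DataParallel(model)          # device_ids defaults to all GPUs; source device becomes cuda:0

x = torch.randn(4, 10, device="cuda:1")
y = model(x)
# RuntimeError: module must have its parameters and buffers on device cuda:0
# (device_ids[0]) but found one of them on device: cuda:1
```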
I am not sure why DataParallel hard-codes everything to the first of all available devices, but regardless, the problem can be solved on the clean-fid side.
Solution: pass device_ids with the device of the model:

```python
if use_dataparallel:
    device_ids = [torch.cuda.current_device()]  # or use next(model.parameters()).device
    model = torch.nn.DataParallel(model, device_ids=device_ids)
def model_fn(x): return model(x)
```
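As a variant (not part of the snippet above), the source device could also be taken from the model itself, which avoids relying on torch.cuda.current_device() having been set to match; a sketch:

```python
# Sketch of an alternative: derive the source device directly from the model's
# parameters. DataParallel accepts torch.device entries in device_ids.
device = next(model.parameters()).device
if use_dataparallel:
    model = torch.nn.DataParallel(model, device_ids=[device])
def model_fn(x): return model(x)
```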
I would be happy to make a PR fixing this. Unless I am missing something?
Cheers,
Jan
janfb changed the title from "cuda device mismatch when not using cuda:0" to "cuda device mismatch in DataParallel when not using cuda:0" on May 15, 2024.