
Training from checkpoint fails when using multiple GPUs #28

Closed
JWHennessey opened this issue Jan 13, 2020 · 7 comments

@JWHennessey

Hi,

Thanks for implementing the paper in PyTorch. I'm having problems resuming training from a specific checkpoint when using multiple GPUs.

When using a single GPU I can run train.py using the following command:
python -m torch.distributed.launch --nproc_per_node=1 train.py --batch 16 --iter 150000 --ckpt checkpoint/start_ckpt.pt dataset
The first sample images look like they are from the point where I previously stopped training, so the checkpoint appears to load correctly.

When I use 4 GPUs I run into a CUDA out-of-memory error, using the command:
python -m torch.distributed.launch --nproc_per_node=4 train.py --batch 16 --iter 150000 --ckpt checkpoint/start_ckpt.pt dataset

I then get the following CUDA out-of-memory error:

load model: checkpoint/start_ckpt.pt
load model: checkpoint/start_ckpt.pt
load model: checkpoint/start_ckpt.pt
load model: checkpoint/start_ckpt.pt
Traceback (most recent call last):
  File "train.py"
    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
  File "train.py", line 189, in train
    real_pred = discriminator(real_img)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 442, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jameswhennessey/stylegan2-pytorch/model.py", line 647, in forward
    out = self.convs(input)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jameswhennessey/stylegan2-pytorch/model.py", line 598, in forward
    out = self.conv2(out)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jameswhennessey/stylegan2-pytorch/op/fused_act.py", line 82, in forward
    return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
  File "/home/jameswhennessey/stylegan2-pytorch/op/fused_act.py", line 86, in fused_leaky_relu
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
  File "/home/jameswhennessey/stylegan2-pytorch/op/fused_act.py", line 55, in forward
    out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
  File "/home/jameswhennessey/stylegan2-pytorch/op/fused_act.py", line 86, in fused_leaky_relu
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
  File "/home/jameswhennessey/stylegan2-pytorch/op/fused_act.py", line 55, in forward
    out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 15.75 GiB total capacity; 8.03 GiB already allocated; 101.38 MiB free; 565.07 MiB cached) (malloc at /opt/conda/conda-bld/pytorch_1573049306803/work/c10/cuda/CUDACachingAllocator.cpp:267)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7f397a01f687 in /opt/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
......

Do you have any suggestions on how to resolve this issue?

Thanks in advance.

@cyrilzakka

Maybe try decreasing the batch size? Also, running

import gc
gc.collect() 

before training might help if you're not starting from a fresh session.
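
A minimal sketch of where that call could sit in train.py, assuming it is placed just before the training loop is entered (the placement is only a suggestion, not code from the repository; the train(...) call is taken from the traceback above):

import gc

# Release any objects left over from an earlier run in the same session
# before the training loop starts allocating GPU memory.
gc.collect()

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)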

@JWHennessey
Author

JWHennessey commented Jan 13, 2020

Thanks for the suggestions.

The error occurs even with a batch size of 1. I also don't get the error when training with a batch size of 16 on a single GPU.

I have tried using gc.collect() but the problem persists.

Any other suggestions would be very welcome. Thanks.

@rosinality
Owner

Does this problem occur if you start training from scratch?

@JWHennessey
Author

The problem doesn't occur when training from scratch.

I also just trained another model from scratch, and the same issue occurs when resuming from the new checkpoint.

@JWHennessey
Author

Interestingly, the problem doesn't occur when using 2 GPUs either; it only appears with 3 or more. I have only tested up to 4.

Any suggestions are very welcome. Thanks.

@JWHennessey
Author

I seem to have resolved the issue. In train.py, after loading the checkpoint, you need to add:

del ckpt  # dereference seems crucial
torch.cuda.empty_cache()

The whole snippet is now:

if args.ckpt is not None:
    print('load model:', args.ckpt)

    ckpt = torch.load(args.ckpt)

    try:
        ckpt_name = os.path.basename(args.ckpt)
        args.start_iter = int(os.path.splitext(ckpt_name)[0])

    except ValueError:
        pass

    generator.load_state_dict(ckpt['g'])
    discriminator.load_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])

    g_optim.load_state_dict(ckpt['g_optim'])
    d_optim.load_state_dict(ckpt['d_optim'])

    del ckpt  # dereference seems crucial
    torch.cuda.empty_cache()

I'll submit a PR.

@afruehstueck

I ran into the same problem, and emptying the cache did not fix it for me. Loading the checkpoint with torch.load(args.ckpt, map_location=lambda storage, loc: storage), as suggested by @rosinality in #31 (comment), seems to work.
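
For anyone hitting this later, a minimal sketch of how that load call fits into the checkpoint-loading block (trimmed from the fuller snippet posted earlier in this thread; only the torch.load line differs):

if args.ckpt is not None:
    print('load model:', args.ckpt)

    # Keep the deserialized tensors on CPU instead of the GPU they were saved
    # from, so each distributed process no longer materializes a full copy of
    # the checkpoint on GPU 0.
    ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

    generator.load_state_dict(ckpt['g'])
    discriminator.load_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])

    g_optim.load_state_dict(ckpt['g_optim'])
    d_optim.load_state_dict(ckpt['d_optim'])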
