CUDA Illegal memory access when starting training #18

Open
marthinwurer opened this issue Nov 30, 2024 · 3 comments

@marthinwurer

0batch [00:00, ?batch/s]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 5
      3 for i in range(runs):
      4     start_time = time.time()
----> 5     result, stats, model = train_run(1e-3)
      6     results.append(result)
      7     print("result:", result)

Cell In[20], line 20, in train_run(lr)
     17 stats = defaultdict(list)
     19 start_time = time.time()
---> 20 train_loop(autoencoder, train_dataloader, optimizer, 10, stats)
     21 result = test_loop(autoencoder, test_dataloader)
     22 return result, stats, autoencoder

Cell In[17], line 24, in train_loop(model, dataloader, optimizer, epochs, stats)
     19             image = data.cuda()
     20 #             print(image.shape)
     21 #             break
     22 
     23     #         loss = train_batch(image, model, optimizer, autoencoder.spectral_loss)
---> 24             loss, losses = train_batch(model, image, image, optimizer, F.mse_loss)
     26             if isinstance(losses, tuple):
     27                 stats['raw_losses'].append(loss)

Cell In[16], line 19, in train_batch(model, inputs, targets, optimizer, loss_func)
     16 loss = F.mse_loss(inputs, outputs)
     17 loss.backward()
---> 19 optimizer.step()
     21 return loss.detach().item(), None

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:601, in StatefulOptimizer.step(self, closure)
    599 for top_group in self.param_groups:
    600     for group in self.get_groups(top_group):
--> 601         self._step(group)
    602         if self.use_ema:
    603             self.ema_update(group)

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/precond_schedule_palm_foreach_soap.py:64, in PrecondSchedulePaLMForeachSOAP._step(self, group)
     62     state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
     63     init_preconditioner(g, state, max_precond_dim, precondition_1d)
---> 64     update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
     65     continue  # first step is skipped so that we never use the current gradients in the projection.
     67 # Projecting gradients to the eigenbases of Shampoo's preconditioner
     68 # i.e. projecting to the eigenbases of matrices in state['GG']

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:423, in update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond)
    421 compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
    422 if state['Q'] is None:
--> 423     state['Q'] = get_orthogonal_matrix(state['GG'])
    424 if update_precond:
    425     get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:320, in get_orthogonal_matrix(mat)
    318 for modifier in (None, torch.double, 'cpu'):
    319     if modifier is not None:
--> 320         m = m.to(modifier)
    321     try:
    322         Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
    323                                                                                         dtype=dtype)

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This is using:

    optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(autoencoder.parameters(), lr=lr)

It runs fine with the default torch SGD optimizer, so I assume the problem is not in my model.
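
For reference, here's a minimal standalone sketch of the path that crashes for me (hypothetical toy model standing in for my actual autoencoder and dataloader):

    import torch
    import torch.nn.functional as F
    import heavyball

    # Hypothetical stand-in for the autoencoder; the real model is larger.
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 64),
    ).cuda()

    optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)

    x = torch.randn(16, 64, device="cuda")
    loss = F.mse_loss(model(x), x)
    loss.backward()
    optimizer.step()  # the illegal memory access surfaces on this first step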

Let me know if you need more information.

@ClashLuke (Owner)

Hm, I haven't seen that myself. Is the trace above with CUDA_LAUNCH_BLOCKING=1? .double() and .cpu() shouldn't cause illegal memory accesses.
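
One note for reproducing: as far as I know, the CUDA runtime reads CUDA_LAUNCH_BLOCKING when the context initializes, so it has to be set before the first CUDA call, e.g.:

    # Set before any CUDA work happens; a freshly restarted
    # kernel/process is the safest way to guarantee this.
    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    import torch  # imported only after the variable is set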

@marthinwurer (Author)

I tried it again this morning, both with CUDA_LAUNCH_BLOCKING=1 and without, and neither triggered the crash. I think it might have had something to do with the other notebooks I still had open holding CUDA contexts, and some interaction between them. I'll close this for now and reopen if I run into it again.

@marthinwurer (Author)

I triggered it again; here's the stack trace with os.environ['CUDA_LAUNCH_BLOCKING'] = "1" set at the start of the Jupyter notebook.

0batch [00:01, ?batch/s]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[22], line 5
      3 for i in range(runs):
      4     start_time = time.time()
----> 5     result, stats, model = train_run(1e-3)
      6     results.append(result)
      7     print("result:", result)

Cell In[21], line 21, in train_run(lr)
     18 stats = defaultdict(list)
     20 start_time = time.time()
---> 21 train_loop(autoencoder, train_dataloader, optimizer, 10, stats)
     22 result = test_loop(autoencoder, test_dataloader)
     23 return result, stats, autoencoder

Cell In[18], line 24, in train_loop(model, dataloader, optimizer, epochs, stats)
     19             image = data.cuda()
     20 #             print(image.shape)
     21 #             break
     22 
     23     #         loss = train_batch(image, model, optimizer, autoencoder.spectral_loss)
---> 24             loss, losses = train_batch(model, image, image, optimizer, F.mse_loss)
     26             if isinstance(losses, tuple):
     27                 stats['raw_losses'].append(loss)

Cell In[17], line 22, in train_batch(model, inputs, targets, optimizer, loss_func)
     19 # loss = F.mse_loss(inputs, outputs)
     20 loss.backward()
---> 22 optimizer.step()
     24 # return loss.detach().item(), None
     25 return loss.detach().item(), tuple(l.detach().item() for l in losses)

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:601, in StatefulOptimizer.step(self, closure)
    599 for top_group in self.param_groups:
    600     for group in self.get_groups(top_group):
--> 601         self._step(group)
    602         if self.use_ema:
    603             self.ema_update(group)

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/precond_schedule_palm_foreach_soap.py:64, in PrecondSchedulePaLMForeachSOAP._step(self, group)
     62     state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
     63     init_preconditioner(g, state, max_precond_dim, precondition_1d)
---> 64     update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
     65     continue  # first step is skipped so that we never use the current gradients in the projection.
     67 # Projecting gradients to the eigenbases of Shampoo's preconditioner
     68 # i.e. projecting to the eigenbases of matrices in state['GG']

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:423, in update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond)
    421 compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
    422 if state['Q'] is None:
--> 423     state['Q'] = get_orthogonal_matrix(state['GG'])
    424 if update_precond:
    425     get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:320, in get_orthogonal_matrix(mat)
    318 for modifier in (None, torch.double, 'cpu'):
    319     if modifier is not None:
--> 320         m = m.to(modifier)
    321     try:
    322         Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
    323                                                                                         dtype=dtype)

RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
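
Reading the trace, the crash still surfaces at m.to(modifier) inside get_orthogonal_matrix. A hypothetical probe I can try next, to check whether the bad access actually originates before the optimizer touches its state (assumes the same training cell as above):

    # Hypothetical debugging probe: force a synchronization between backward()
    # and optimizer.step() so any pending async error is raised at the sync
    # point rather than being attributed to a later, unrelated op.
    loss.backward()
    torch.cuda.synchronize()
    optimizer.step()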

@marthinwurer reopened this Dec 3, 2024