
Running the documentation section “降低内存占用” (reduce memory consumption) fails #556

Open · DayBeha opened this issue Jun 25, 2024 · 0 comments

DayBeha commented Jun 25, 2024

When running the following code locally, from the documentation section on reducing memory consumption:

import torch

def tensor_memory(x: torch.Tensor):
    # total memory footprint of the tensor, in bytes
    return x.element_size() * x.numel()

N = 1 << 10
spike = torch.randint(0, 2, [N]).float()  # 0/1 spike tensor stored as float32

print('float32 size =', tensor_memory(spike))
print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

from spikingjelly.activation_based import tensor_cache

# compress the float spike tensor into a compact bool representation
spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)

print('bool size =', tensor_memory(spike_b))

# recover the original float tensor from the compressed form
spike_recover = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)

print('spike == spike_recover?', torch.equal(spike, spike_recover))
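
For context, these are the sizes I expected to see (my own arithmetic, not measured output, assuming the conversion packs 8 spikes per byte, which is what the returned s_padding suggests to me):

N = 1 << 10
print('float32:', 4 * N)   # 4 bytes per element -> 4096
print('bool:   ', 1 * N)   # 1 byte per element  -> 1024
print('packed: ', N // 8)  # 1 bit per element   -> 128 (my assumption)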

the following error is raised:


AssertionError                            Traceback (most recent call last)
Cell In[24], line 11
      8 N = 1 << 20
      9 spike = (torch.rand([N]) > 0.8).float()
---> 11 spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
     13 arr = spike_b.numpy()
     15 compressed_arr = zlib.compress(arr.tobytes())

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/tensor_cache.py:123, in float_spike_to_bool(spike)
    115     kernel_args = [spike, spike_b, numel]
    116     kernel = cupy.RawKernel(
    117         kernel_codes,
    118         kernel_name,
    119         options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend
    120     )
    121     kernel(
    122         (blocks,), (configure.cuda_threads,),
--> 123         cuda_utils.wrap_args_to_raw_kernel(
    124             device_id,
    125             *kernel_args
    126         )
    127     )
    128 return spike_b, s_dtype, s_shape, s_padding

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/cuda_utils.py:249, in wrap_args_to_raw_kernel(device, *args)
    246     ret_list.append(item.data_ptr())
    248 elif isinstance(item, cupy.ndarray):
--> 249     assert item.device.id == device
    250     assert item.flags['C_CONTIGUOUS']
    251     ret_list.append(item)

AssertionError: 

I followed the tutorial exactly. What could be causing this, and how can I fix it?
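
For reference, the assertion that fails (cuda_utils.py line 249, `assert item.device.id == device`) compares the cupy array's device id against the device id expected by the kernel, so my guess is a device mismatch. Below is an untested sketch of the variant I would try next, with the spike tensor created explicitly on one GPU; using cuda:0, and the need for a CUDA tensor at all, are assumptions on my part:

import torch
from spikingjelly.activation_based import tensor_cache

# Assumption: pinning the input tensor to a single, explicitly chosen GPU
# makes the device id seen by the cupy kernel match the input tensor's device.
device = 'cuda:0'
torch.cuda.set_device(0)  # select the same device for torch

N = 1 << 10
spike = torch.randint(0, 2, [N], device=device).float()
spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
print('bool size =', spike_b.element_size() * spike_b.numel())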
