When running the following code from the documentation's "Reducing Memory Usage" (降低内存占用) tutorial locally:
```python
import torch


def tensor_memory(x: torch.Tensor):
    return x.element_size() * x.numel()


N = 1 << 10
spike = torch.randint(0, 2, [N]).float()

print('float32 size =', tensor_memory(spike))
print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

from spikingjelly.activation_based import tensor_cache

spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
print('bool size =', tensor_memory(spike_b))

spike_recover = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)
print('spike == spike_recover?', torch.equal(spike, spike_recover))
```
I get the following error:
```
AssertionError                            Traceback (most recent call last)
Cell In[24], line 11
      8 N = 1 << 20
      9 spike = (torch.rand([N]) > 0.8).float()
---> 11 spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
     13 arr = spike_b.numpy()
     15 compressed_arr = zlib.compress(arr.tobytes())

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/tensor_cache.py:123, in float_spike_to_bool(spike)
    115 kernel_args = [spike, spike_b, numel]
    116 kernel = cupy.RawKernel(
    117     kernel_codes,
    118     kernel_name,
    119     options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend
    120 )
    121 kernel(
    122     (blocks,), (configure.cuda_threads,),
--> 123     cuda_utils.wrap_args_to_raw_kernel(
    124         device_id,
    125         *kernel_args
    126     )
    127 )
    128 return spike_b, s_dtype, s_shape, s_padding

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/cuda_utils.py:249, in wrap_args_to_raw_kernel(device, *args)
    246     ret_list.append(item.data_ptr())
    248 elif isinstance(item, cupy.ndarray):
--> 249     assert item.device.id == device
    250     assert item.flags['C_CONTIGUOUS']
    251     ret_list.append(item)

AssertionError:
```
I followed the tutorial exactly. What could be causing this, and how can I fix it?
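For context, the tutorial's point is that a 0/1 spike tensor can be stored far more compactly than as float32 (4 bytes per element) or `torch.bool` (1 byte per element). The traceback shows that `float_spike_to_bool` does this by launching a CuPy `RawKernel`, i.e. the conversion runs as a CUDA kernel. The pure-Python sketch below illustrates the same bit-packing idea on the CPU, under stated assumptions: `pack_spikes` and `unpack_spikes` are hypothetical helpers, not SpikingJelly API, and their bit order and padding convention are not necessarily the library's.

```python
from typing import List, Tuple


def pack_spikes(spikes: List[float]) -> Tuple[bytes, int, int]:
    """Hypothetical helper: pack 0/1 spikes into bits, 8 spikes per byte.

    Returns the packed bytes, the original length, and the number of
    zero spikes appended so that the length is a multiple of 8.
    """
    n = len(spikes)
    padding = (8 - n % 8) % 8
    padded = list(spikes) + [0.0] * padding
    out = bytearray(len(padded) // 8)
    for i, s in enumerate(padded):
        if s != 0.0:
            # Assumed layout: spike i is stored in bit (i % 8) of byte (i // 8).
            out[i // 8] |= 1 << (i % 8)
    return bytes(out), n, padding


def unpack_spikes(packed: bytes, n: int) -> List[float]:
    """Hypothetical helper: expand bits back to floats and drop the padding."""
    spikes = [float((packed[i // 8] >> (i % 8)) & 1) for i in range(len(packed) * 8)]
    return spikes[:n]


if __name__ == "__main__":
    import random

    N = 1 << 10
    spike = [float(random.randint(0, 1)) for _ in range(N)]
    packed, n, padding = pack_spikes(spike)
    print("float32 size =", 4 * N)        # 4 bytes per float32 element
    print("bool size    =", 1 * N)        # 1 byte per torch.bool element
    print("packed size  =", len(packed))  # 1 bit per spike
    print("spike == recovered?", unpack_spikes(packed, n) == spike)
```

With `N = 1 << 10` this prints sizes of 4096, 1024 and 128 bytes, followed by `True`, since packing stores one bit per spike and the round trip is lossless for 0/1 values.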