
Passing hidden states between successive time steps in SpikingLSTM #544

Open · 1 of 4 tasks
Seazoned opened this issue Jun 2, 2024 · 1 comment
Comments


Seazoned commented Jun 2, 2024

Read before creating a new issue

  • Users who want to use SpikingJelly should first be familiar with the usage of PyTorch.
  • If you do not know much about PyTorch, we recommend starting with the basic PyTorch tutorials.
  • Do not ask for help with basic concepts of PyTorch/machine learning that are unrelated to SpikingJelly. For such questions, please refer to Google or the PyTorch Forums.

For faster response

You can @ the corresponding developers for your issue. Here is the division:

| Features | Developers |
| --- | --- |
| Neurons and Surrogate Functions | fangwei123456, Yanqi-Chen |
| CUDA Acceleration | fangwei123456, Yanqi-Chen |
| Reinforcement Learning | lucifer2859 |
| ANN to SNN Conversion | DingJianhao, Lyu6PosHao |
| Biological Learning (e.g., STDP) | AllenYolk |
| Others | Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456 |

We are glad to add to the above table any new developers who volunteer to help solve issues.

Issue type

  • Bug Report
  • Feature Request
  • Help wanted
  • Other

SpikingJelly version

0.0.0.0.14

Description

When processing sequential data, the hidden state of the SpikingLSTM at the previous step needs to be passed in as the initial hidden state for the next step, but doing so raises an error.
The code is as follows.

Minimal code to reproduce the error/bug

```python
import torch
import torch.nn as nn
from spikingjelly.activation_based import rnn, neuron, layer, surrogate

T = 6       # time steps per forward call
h_dim = 32  # input / hidden feature size
batch = 5

x = torch.randn([T, batch, h_dim])
lstm = rnn.SpikingLSTM(32, 32, 1)  # input_size=32, hidden_size=32, num_layers=1

# Feed the states returned by one call back in as the initial states
# of the next call; the second iteration raises the error below.
states = None
for t in range(8):
    out, states = lstm(x, states)
```

The error message is as follows:

```
Traceback (most recent call last):
  File "E:\SNN\SNN_LSTM\SNN_LSTM\encoder.py", line 80, in
    out, states = lstm(x, states)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\spikingjelly\activation_based\rnn.py", line 473, in forward
    new_states_list[:, 0] = torch.stack(self.cells[0](x[t], states_list[:, 0]))
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\anaconda3\envs\social\lib\site-packages\spikingjelly\activation_based\rnn.py", line 685, in forward
    i, f, g, o = torch.split(self.surrogate_function1(self.linear_ih(x) + self.linear_hh(h)),
ValueError: not enough values to unpack (expected 4, got 1)
```
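
Judging from the shapes involved, the `ValueError` means the tensor handed to `torch.split` has only `batch` (5) elements along dim 1 instead of `4 * hidden_size`, which is what happens when the hidden state `h` arrives with an extra leading dimension and `self.linear_hh(h)` broadcasts to a 3-D tensor. A minimal sketch of a workaround, assuming the states returned by `SpikingLSTM.forward` carry such an extra leading dimension of size 1 (this diagnosis and the `squeeze(0)` fix are assumptions inferred from the traceback, not confirmed SpikingJelly behavior):

```python
import torch
from spikingjelly.activation_based import rnn

T, batch, h_dim = 6, 5, 32
x = torch.randn([T, batch, h_dim])
lstm = rnn.SpikingLSTM(h_dim, h_dim, 1)

states = None
for t in range(8):
    out, states = lstm(x, states)
    # Assumption: each returned state tensor is 4-D, e.g.
    # [1, num_layers, batch, hidden_size]; drop the extra leading dim
    # so the next call receives [num_layers, batch, hidden_size].
    states = tuple(s.squeeze(0) if s.dim() == 4 else s for s in states)
```

Check `states[0].shape` after the first call to confirm whether the extra dimension is actually present in your installation before applying the squeeze.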

Seazoned (Author) commented Jun 2, 2024

Perhaps I should describe it differently:
The input tensor of PyTorch's LSTM has size (sequence length, batch_size, feature size).
The input tensor of the LSTM in SpikingJelly has dimensions (T, batch_size, feature size), where T, as I understand it, is the number of time steps of the spike train. In other words, the snnLSTM no longer has the original "sequence length" dimension, which is why, when feeding a sequence of a given length, I tried to carry the hidden state through the LSTM with a loop.
Is my understanding wrong? Any guidance would be appreciated.
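
For reference, the `x[t]` indexing visible in the traceback's `rnn.py` frame suggests (this is my reading, not documented behavior) that `SpikingLSTM.forward` steps through the first dimension of the input internally, so `T` plays the same role as PyTorch's sequence-length dimension and a whole sequence can be consumed in one call; an outer Python loop would then only be needed to carry states across separate chunks of a longer stream. A minimal sketch under that assumption:

```python
import torch
from spikingjelly.activation_based import rnn

T, batch, h_dim = 6, 5, 32
lstm = rnn.SpikingLSTM(h_dim, h_dim, 1)

# One call steps through all T elements of the sequence, analogous to
# nn.LSTM consuming a (seq_len, batch, feature) tensor in one forward.
x = torch.randn([T, batch, h_dim])
out, states = lstm(x)   # out has shape [T, batch, h_dim]
```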
