diff --git a/draw_model.py b/draw_model.py index 9994919..2c00aff 100644 --- a/draw_model.py +++ b/draw_model.py @@ -19,8 +19,8 @@ def __init__(self,T,A,B,z_size,N,dec_size,enc_size): self.encoder = nn.LSTMCell(2 * N * N + dec_size, enc_size) self.encoder_gru = nn.GRUCell(2 * N * N + dec_size, enc_size) - self.mu_linear = nn.Linear(dec_size, z_size) - self.sigma_linear = nn.Linear(dec_size, z_size) + self.mu_linear = nn.Linear(enc_size, z_size) + self.sigma_linear = nn.Linear(enc_size, z_size) self.decoder = nn.LSTMCell(z_size,dec_size) self.decoder_gru = nn.GRUCell(z_size,dec_size)