diff --git a/trainer.py b/trainer.py index b163064..cf75fb8 100644 --- a/trainer.py +++ b/trainer.py @@ -165,7 +165,8 @@ def test(model, args, block_len = 'default',use_cuda = False): num_test_batch = int(args.num_block/(args.batch_size)) for batch_idx in range(num_test_batch): X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) - fwd_noise = generate_noise(X_test.shape, args, test_sigma=sigma) + noise_shape = (args.batch_size, args.block_len, args.code_rate_n) + fwd_noise = generate_noise(noise_shape, args, test_sigma=sigma) X_test, fwd_noise= X_test.to(device), fwd_noise.to(device)