From 2a9412161e99daf4495ff1b4d48d1a08001e212e Mon Sep 17 00:00:00 2001
From: yihanjiang
Date: Fri, 10 Jan 2020 17:59:38 -0800
Subject: [PATCH] add a lot of experimental codes..

---
 channel_ae.py      |  55 +++++++++
 channels.py        |  39 +++---
 cnn_utils.py       |   8 +-
 decoders.py        |  13 +-
 encoders.py        |  42 ++-----
 get_args.py        |  31 ++++-
 main.py            |  35 ++++--
 main_modulation.py | 293 +++++++++++++++++++++++++++++++++++++++++++++
 mod_trainer.py     | 268 +++++++++++++++++++++++++++++++++++++++++
 modulation.py      |  34 ------
 modulations.py     | 110 +++++++++++++++++
 trainer.py         |  63 ++++++++--
 utils.py           |  24 +++-
 13 files changed, 890 insertions(+), 125 deletions(-)
 create mode 100644 main_modulation.py
 create mode 100644 mod_trainer.py
 delete mode 100644 modulation.py
 create mode 100644 modulations.py

diff --git a/channel_ae.py b/channel_ae.py
index 3467a4a..4d0ede1 100644
--- a/channel_ae.py
+++ b/channel_ae.py
@@ -71,3 +71,58 @@ def forward(self, input, fwd_noise):
         x_dec = self.dec(received_codes)
 
         return x_dec, codes
+
+
+
+class Channel_ModAE(torch.nn.Module):
+    def __init__(self, args, enc, dec, mod, demod, modulation = 'qpsk'):
+        super(Channel_ModAE, self).__init__()
+        use_cuda = not args.no_cuda and torch.cuda.is_available()
+        self.this_device = torch.device("cuda" if use_cuda else "cpu")
+
+        self.args  = args
+        self.enc   = enc
+        self.dec   = dec
+        self.mod   = mod
+        self.demod = demod
+
+    def forward(self, input, fwd_noise):
+        # Setup Interleavers.
+        if self.args.is_interleave == 0:
+            pass
+
+        elif self.args.is_same_interleaver == 0:
+            interleaver = RandInterlv.RandInterlv(self.args.block_len, np.random.randint(0, 1000))
+
+            p_array = interleaver.p_array
+            self.enc.set_interleaver(p_array)
+            self.dec.set_interleaver(p_array)
+
+        else:  # self.args.is_same_interleaver == 1
+            interleaver = RandInterlv.RandInterlv(self.args.block_len, 0)  # not random anymore!
+            p_array = interleaver.p_array
+            self.enc.set_interleaver(p_array)
+            self.dec.set_interleaver(p_array)
+
+        codes   = self.enc(input)
+        symbols = self.mod(codes)
+
+        # Setup channel mode:
+        if self.args.channel in ['awgn', 't-dist', 'radar', 'ge_awgn']:
+            received_symbols = symbols + fwd_noise
+
+        elif self.args.channel == 'fading':
+            # fading is not implemented yet; fall back to AWGN so that
+            # received_symbols is always defined.
+            print('Fading not implemented, falling back to AWGN')
+            received_symbols = symbols + fwd_noise
+
+        else:
+            print('default AWGN channel')
+            received_symbols = symbols + fwd_noise
+
+        if self.args.rec_quantize:
+            myquantize = MyQuantize.apply
+            received_symbols = myquantize(received_symbols, self.args.rec_quantize_level, self.args.rec_quantize_level)
+
+        x_rec = self.demod(received_symbols)
+        x_dec = self.dec(x_rec)
+
+        return x_dec, symbols
diff --git a/channels.py b/channels.py
index da6538d..dcccb65 100644
--- a/channels.py
+++ b/channels.py
@@ -4,7 +4,7 @@ from utils import snr_db2sigma, snr_sigma2db
 import numpy as np
 
-def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
+def generate_noise(noise_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
     # SNRs at training
     if test_sigma == 'default':
         if args.channel == 'bec':
@@ -22,7 +22,7 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s
             this_sigma_low = snr_db2sigma(snr_low)
             this_sigma_high= snr_db2sigma(snr_high)
             # mixture of noise sigma.
- this_sigma = (this_sigma_low - this_sigma_high) * torch.rand((X_train_shape[0], X_train_shape[1], args.code_rate_n)) + this_sigma_high + this_sigma = (this_sigma_low - this_sigma_high) * torch.rand(noise_shape) + this_sigma_high else: if args.channel in ['bec', 'bsc', 'ge']: # bsc/bec noises @@ -32,25 +32,25 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s # SNRs at testing if args.channel == 'awgn': - fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) + fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float) elif args.channel == 't-dist': - fwd_noise = this_sigma * torch.from_numpy(np.sqrt((args.vv-2)/args.vv) * np.random.standard_t(args.vv, size = (X_train_shape[0], X_train_shape[1], args.code_rate_n))).type(torch.FloatTensor) + fwd_noise = this_sigma * torch.from_numpy(np.sqrt((args.vv-2)/args.vv) * np.random.standard_t(args.vv, size = noise_shape)).type(torch.FloatTensor) elif args.channel == 'radar': - add_pos = np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n), + add_pos = np.random.choice([0.0, 1.0], noise_shape, p=[1 - args.radar_prob, args.radar_prob]) - corrupted_signal = args.radar_power* np.random.standard_normal( size = (X_train_shape[0], X_train_shape[1], args.code_rate_n) ) * add_pos - fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) +\ + corrupted_signal = args.radar_power* np.random.standard_normal( size = noise_shape ) * add_pos + fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float) +\ torch.from_numpy(corrupted_signal).type(torch.FloatTensor) elif args.channel == 'bec': - fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n), + fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], noise_shape, p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor) elif args.channel == 'bsc': - fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n), + fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], noise_shape, p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor) elif args.channel == 'ge_awgn': #G-E AWGN channel @@ -59,12 +59,12 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s bsc_k = snr_db2sigma(snr_sigma2db(this_sigma) + 1) # accuracy on good state bsc_h = snr_db2sigma(snr_sigma2db(this_sigma) - 1) # accuracy on good state - fwd_noise = np.zeros((X_train_shape[0], X_train_shape[1], args.code_rate_n)) - for batch_idx in range(X_train_shape[0]): - for code_idx in range(args.code_rate_n): + fwd_noise = np.zeros(noise_shape) + for batch_idx in range(noise_shape[0]): + for code_idx in range(noise_shape[2]): good = True - for time_idx in range(X_train_shape[1]): + for time_idx in range(noise_shape[1]): if good: if test_sigma == 'default': fwd_noise[batch_idx,time_idx, code_idx] = bsc_k[batch_idx,time_idx, code_idx] @@ -80,7 +80,7 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s else: print('bad!!! 
something happens') - fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)* torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) + fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)* torch.randn(noise_shape, dtype=torch.float) elif args.channel == 'ge': #G-E discrete channel @@ -89,12 +89,12 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s bsc_k = 1.0 # accuracy on good state bsc_h = this_sigma# accuracy on good state - fwd_noise = np.zeros((X_train_shape[0], X_train_shape[1], args.code_rate_n)) - for batch_idx in range(X_train_shape[0]): - for code_idx in range(args.code_rate_n): + fwd_noise = np.zeros(noise_shape) + for batch_idx in range(noise_shape[0]): + for code_idx in range(noise_shape[2]): good = True - for time_idx in range(X_train_shape[1]): + for time_idx in range(noise_shape[1]): if good: tmp = np.random.choice([0.0, 1.0], p=[1-bsc_k, bsc_k]) fwd_noise[batch_idx,time_idx, code_idx] = tmp @@ -110,10 +110,9 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, s else: # Unspecific channel, use AWGN channel. - fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) + fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float) return fwd_noise - diff --git a/cnn_utils.py b/cnn_utils.py index 2a00169..20a22b0 100644 --- a/cnn_utils.py +++ b/cnn_utils.py @@ -4,11 +4,12 @@ # utility for Same Shape CNN 1D class SameShapeConv1d(torch.nn.Module): - def __init__(self, num_layer, in_channels, out_channels, kernel_size, activation = 'elu'): + def __init__(self, num_layer, in_channels, out_channels, kernel_size, activation = 'elu', no_act = False): super(SameShapeConv1d, self).__init__() self.cnns = torch.nn.ModuleList() self.num_layer = num_layer + self.no_act = no_act for idx in range(num_layer): if idx == 0: self.cnns.append(torch.nn.Conv1d(in_channels = in_channels, out_channels=out_channels, @@ -36,7 +37,10 @@ def forward(self, inputs): inputs = torch.transpose(inputs, 1,2) x = inputs for idx in range(self.num_layer): - x = self.activation(self.cnns[idx](x)) + if self.no_act: + x = self.cnns[idx](x) + else: + x = self.activation(self.cnns[idx](x)) outputs = torch.transpose(x, 1,2) return outputs diff --git a/decoders.py b/decoders.py index e7d0838..d48614b 100644 --- a/decoders.py +++ b/decoders.py @@ -277,22 +277,15 @@ def forward(self, received): from encoders import SameShapeConv1d class DEC_LargeCNN2Int(torch.nn.Module): - def __init__(self, args, p_array): + def __init__(self, args, p_array1, p_array2): super(DEC_LargeCNN2Int, self).__init__() self.args = args use_cuda = not args.no_cuda and torch.cuda.is_available() self.this_device = torch.device("cuda" if use_cuda else "cpu") - self.interleaver1 = Interleaver(args, p_array) - self.deinterleaver1 = DeInterleaver(args, p_array) - - seed2 = 1000 - rand_gen2 = mtrand.RandomState(seed2) - p_array2 = rand_gen2.permutation(arange(args.block_len)) - - print('p_array1 dec', p_array) - print('p_array2 dec', p_array2) + self.interleaver1 = Interleaver(args, p_array1) + self.deinterleaver1 = DeInterleaver(args, p_array1) self.interleaver2 = Interleaver(args, p_array2) self.deinterleaver2 = DeInterleaver(args, p_array2) diff --git a/encoders.py b/encoders.py index cebe48d..90766e8 100644 --- a/encoders.py +++ b/encoders.py @@ -101,35 +101,28 @@ def enc_act(self, inputs): def power_constraint(self, x_input): - if not self.args.precompute_norm_stats: + if 
self.args.no_code_norm: + return x_input + else: this_mean = torch.mean(x_input) this_std = torch.std(x_input) - x_input_norm = (x_input-this_mean)*1.0 / this_std - else: - x_input_norm = (x_input - self.mean_scalar)/self.std_scalar - if self.training: - # save pretrained mean/std. Pretrained not implemented in this code version. - try: + if self.args.precompute_norm_stats: self.num_test_block += 1.0 self.mean_scalar = (self.mean_scalar*(self.num_test_block-1) + this_mean)/self.num_test_block self.std_scalar = (self.std_scalar*(self.num_test_block-1) + this_std)/self.num_test_block - except: - print('group normalization seems wired.!') + x_input_norm = (x_input - self.mean_scalar)/self.std_scalar + else: + x_input_norm = (x_input-this_mean)*1.0 / this_std if self.args.train_channel_mode == 'block_norm_ste': stequantize = STEQuantize.apply x_input_norm = stequantize(x_input_norm, self.args) - else: - if self.args.test_channel_mode == 'block_norm_ste': - stequantize = STEQuantize.apply - x_input_norm = stequantize(x_input_norm, self.args) - - if self.args.enc_truncate_limit>0: - x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit) + if self.args.enc_truncate_limit>0: + x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit) - return x_input_norm + return x_input_norm # Encoder with interleaver. Support different code rate. class ENC_turbofy_rate2(ENCBase): @@ -388,7 +381,7 @@ def forward(self, inputs): ####################################################### from cnn_utils import SameShapeConv1d class ENC_interCNN2Int(ENCBase): - def __init__(self, args, p_array): + def __init__(self, args, p_array1, p_array2): # turbofy only for code rate 1/3 super(ENC_interCNN2Int, self).__init__(args) self.args = args @@ -411,16 +404,7 @@ def __init__(self, args, p_array): self.enc_linear_3 = torch.nn.Linear(args.enc_num_unit, 1) - self.interleaver1 = Interleaver(args, p_array) - - - seed2 = 1000 - rand_gen2 = mtrand.RandomState(seed2) - p_array2 = rand_gen2.permutation(arange(args.block_len)) - - print('p_array1', p_array) - print('p_array2', p_array2) - + self.interleaver1 = Interleaver(args, p_array1) self.interleaver2 = Interleaver(args, p_array2) @@ -451,7 +435,7 @@ def forward(self, inputs): x_p2 = self.enc_cnn_3(x_sys_int2) x_p2 = self.enc_act(self.enc_linear_3(x_p2)) - x_tx = torch.cat([x_sys,x_p1, x_p2], dim = 2) + x_tx = torch.cat([x_sys, x_p1, x_p2], dim = 2) codes = self.power_constraint(x_tx) diff --git a/get_args.py b/get_args.py index b3deddd..9cce9dd 100644 --- a/get_args.py +++ b/get_args.py @@ -19,6 +19,7 @@ def get_args(): 'rate3_cnn2d', 'Turbo_rate3_757', # Turbo Code, rate 1/3, 757. 'Turbo_rate3_lte', # Turbo Code, rate 1/3, LTE. + 'turboae_2int', # experimental, use multiple interleavers ], default='TurboAE_rate3_cnn2d') @@ -33,6 +34,7 @@ def get_args(): 'nbcjr_rate3', # NeuralBCJR Decoder, rate 1/3, allow ft size. 'rate3_cnn', # CNN Encoder, rate 1/3. 
No Interleaver
         'rate3_cnn2d',
+        'turboae_2int',         # experimental, use multiple interleavers
         ],
         default='TurboAE_rate3_cnn2d')
 
     ################################################################
@@ -81,7 +83,7 @@ def get_args():
     parser.add_argument('-extrinsic', type=int, default=1)
     parser.add_argument('-num_iter_ft', type=int, default=5)
     parser.add_argument('-is_interleave', type=int, default=1, help='0 is not interleaving, 1 is fixed interleaver, >1 is random interleaver')
-    parser.add_argument('-is_same_interleaver', type=int, default=0, help='not random interleaver, potentially finetune?')
+    parser.add_argument('-is_same_interleaver', type=int, default=1, help='1: reuse the same fixed interleaver for every batch (useful for fine-tuning); 0: re-randomize it')
     parser.add_argument('-is_parallel', type=int, default=0)
     # CNN related
     parser.add_argument('-enc_kernel_size', type=int, default=5)
@@ -91,12 +93,14 @@ def get_args():
     parser.add_argument('-enc_num_layer', type=int, default=2)
     parser.add_argument('-dec_num_layer', type=int, default=5)
 
-    parser.add_argument('-enc_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')
+    parser.add_argument('-dec_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')
+    parser.add_argument('-enc_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')
 
     parser.add_argument('-enc_act', choices=['tanh', 'selu', 'relu', 'elu', 'sigmoid', 'linear'], default='elu', help='only elu works')
     parser.add_argument('-dec_act', choices=['tanh', 'selu', 'relu', 'elu', 'sigmoid', 'linear'], default='linear')
 
+    parser.add_argument('-num_ber_puncture', type=int, default=5, help = 'number of worst-BER positions to puncture')
 
     ################################################################
     # Training ALgorithm related parameters
@@ -133,6 +137,29 @@ def get_args():
         default='block_norm')
     parser.add_argument('-enc_truncate_limit', type=float, default=0, help='0 means no truncation')
 
+    ################################################################
+    # Modulation related parameters
+    ################################################################
+    parser.add_argument('-mod_rate', type=int, default=2, help = 'code: (B, L, R), mod output: (B, L*R/mod_rate, 2)')
+    parser.add_argument('-mod_num_layer', type=int, default=1, help = 'number of CNN layers in the modulator')
+    parser.add_argument('-mod_num_unit', type=int, default=20, help = 'number of filters per modulator layer')
+    parser.add_argument('-demod_num_layer', type=int, default=1, help = 'number of CNN layers in the demodulator')
+    parser.add_argument('-demod_num_unit', type=int, default=20, help = 'number of filters per demodulator layer')
+
+    parser.add_argument('-mod_lr', type = float, default=0.005, help='modulation learning rate')
+    parser.add_argument('-demod_lr', type = float, default=0.005, help='demodulation learning rate')
+
+    parser.add_argument('-num_train_mod', type=int, default=1, help = 'modulator training steps per epoch')
+    parser.add_argument('-num_train_demod', type=int, default=5, help = 'demodulator training steps per epoch')
+
+    parser.add_argument('-mod_pc',
+                        choices=['qpsk','symbol_power', 'block_power'],
+                        default='block_power')
+
+    parser.add_argument('--no_code_norm', action='store_true', default=False,
+                        help='do not normalize the encoder output; the modulation layer does the power normalization instead')
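For reference, the -mod_rate bookkeeping spelled out as a quick sketch with the default sizes (the numbers are illustrative, not from this patch):

    # code_rate_n = 3 coded values per message bit, block_len = 100, mod_rate = 2:
    # each block's 300 coded values are regrouped into 300/2 = 150 symbols,
    # and the modulator maps each group to one (real, imag) pair.
    B, L, R, mod_rate = 500, 100, 3, 2
    num_symbols = (L * R) // mod_rate
    print(num_symbols)      # 150 -> modulator output shape (500, 150, 2)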
+
+
 
     ################################################################
     # STE related parameters
diff --git a/main.py b/main.py
index b39c73d..23e0549 100644
--- a/main.py
+++ b/main.py
@@ -35,7 +35,7 @@ def import_enc(args):
     elif args.encoder in ['TurboAE_rate3_cnn', 'TurboAE_rate3_cnn_dense']:
         from encoders import ENC_interCNN as ENC
 
-    elif args.encoder == 'TurboAE_rate3_cnn_2inter':
+    elif args.encoder == 'turboae_2int':
         from encoders import ENC_interCNN2Int as ENC
 
     elif args.encoder == 'rate3_cnn':
@@ -75,7 +75,7 @@ def import_dec(args):
     elif args.decoder in ['TurboAE_rate3_cnn', 'TurboAE_rate3_cnn_dense']:
         from decoders import DEC_LargeCNN as DEC
 
-    elif args.decoder == 'TurboAE_rate3_cnn_2inter':
+    elif args.decoder == 'turboae_2int':
         from decoders import DEC_LargeCNN2Int as DEC
 
     elif args.encoder == 'rate3_cnn':
@@ -123,25 +123,36 @@ def import_dec(args):
     if args.is_interleave == 1:           # fixed interleaver.
         seed = np.random.randint(0, 1)
         rand_gen = mtrand.RandomState(seed)
-        p_array = rand_gen.permutation(arange(args.block_len))
+        p_array1 = rand_gen.permutation(arange(args.block_len))
+        p_array2 = rand_gen.permutation(arange(args.block_len))
 
     elif args.is_interleave == 0:
-        p_array = range(args.block_len)   # no interleaver.
+        p_array1 = range(args.block_len)  # no interleaver.
+        p_array2 = range(args.block_len)  # no interleaver.
     else:
         seed = np.random.randint(0, args.is_interleave)
         rand_gen = mtrand.RandomState(seed)
-        p_array = rand_gen.permutation(arange(args.block_len))
-
-    print('using random interleaver', p_array)
+        p_array1 = rand_gen.permutation(arange(args.block_len))
+        seed = np.random.randint(0, args.is_interleave)
+        rand_gen = mtrand.RandomState(seed)
+        p_array2 = rand_gen.permutation(arange(args.block_len))
 
-    encoder = ENC(args, p_array)
-    decoder = DEC(args, p_array)
+    print('using random interleaver', p_array1, p_array2)
+
+    if args.encoder == 'turboae_2int' and args.decoder == 'turboae_2int':
+        encoder = ENC(args, p_array1, p_array2)
+        decoder = DEC(args, p_array1, p_array2)
+    else:
+        encoder = ENC(args, p_array1)
+        decoder = DEC(args, p_array1)
 
     # choose support channels
     from channel_ae import Channel_AE
     model = Channel_AE(args, encoder, decoder).to(device)
+    # for the learned-modulation pipeline (Channel_ModAE), see main_modulation.py.
+
+
     # make the model parallel
     if args.is_parallel == 1:
         model.enc.set_parallel()
@@ -234,6 +245,9 @@
     # Testing Processes
     #################################################
 
+    torch.save(model.state_dict(), './tmp/torch_model_'+identity+'.pt')
+    print('saved model', './tmp/torch_model_'+identity+'.pt')
+
     if args.is_variable_block_len:
         print('testing block length',args.block_len_low )
         test(model, args, block_len=args.block_len_low, use_cuda = use_cuda)
@@ -245,8 +259,7 @@
     else:
         test(model, args, use_cuda = use_cuda)
 
-    torch.save(model.state_dict(), './tmp/torch_model_'+identity+'.pt')
-    print('saved model', './tmp/torch_model_'+identity+'.pt')
+
diff --git a/main_modulation.py b/main_modulation.py
new file mode 100644
index 0000000..a9a8f65
--- /dev/null
+++ b/main_modulation.py
@@ -0,0 +1,293 @@
+__author__ = 'yihanjiang'
+# update 10/18/2019, code to replicate TurboAE paper in NeurIPS 2019.
+# Tested on PyTorch 1.0.
+# TBD: remove all non-TurboAE related functions.
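Before the code, a minimal smoke test of the new modulation path, using the Modulation and DeModulation modules that modulations.py adds later in this patch (hypothetical sizes; assumes the repo is on the Python path):

    import torch
    from types import SimpleNamespace
    from modulations import Modulation, DeModulation

    # only the fields the two modules actually read; values mirror the get_args.py defaults.
    args = SimpleNamespace(no_cuda=True, batch_size=4, block_len=100, code_rate_n=3,
                           mod_rate=2, mod_num_layer=1, mod_num_unit=20,
                           demod_num_layer=1, demod_num_unit=20, mod_pc='block_power')

    mod, demod = Modulation(args), DeModulation(args)
    codes   = torch.randn(args.batch_size, args.block_len, args.code_rate_n)
    symbols = mod(codes)      # torch.Size([4, 150, 2]): one (real, imag) pair per symbol
    soft    = demod(symbols)  # torch.Size([4, 100, 3]): back in code space
    print(symbols.shape, soft.shape)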
+
+import torch
+import torch.optim as optim
+import numpy as np
+import sys
+from get_args import get_args
+from mod_trainer import train, validate, test
+
+from numpy import arange
+from numpy.random import mtrand
+
+# utils for logger
+class Logger(object):
+    def __init__(self, filename, stream=sys.stdout):
+        self.terminal = stream
+        self.log = open(filename, 'a')
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        pass
+
+def import_enc(args):
+    # choose encoder
+
+    if args.encoder == 'TurboAE_rate3_rnn':
+        from encoders import ENC_interRNN as ENC
+
+    elif args.encoder in ['TurboAE_rate3_cnn', 'TurboAE_rate3_cnn_dense']:
+        from encoders import ENC_interCNN as ENC
+
+    elif args.encoder == 'turboae_2int':
+        from encoders import ENC_interCNN2Int as ENC
+
+    elif args.encoder == 'rate3_cnn':
+        from encoders import CNN_encoder_rate3 as ENC
+
+    elif args.encoder in ['TurboAE_rate3_cnn2d', 'TurboAE_rate3_cnn2d_dense']:
+        from encoders import ENC_interCNN2D as ENC
+
+    elif args.encoder == 'TurboAE_rate3_rnn_sys':
+        from encoders import ENC_interRNN_sys as ENC
+
+    elif args.encoder == 'TurboAE_rate2_rnn':
+        from encoders import ENC_turbofy_rate2 as ENC
+
+    elif args.encoder == 'TurboAE_rate2_cnn':
+        from encoders import ENC_turbofy_rate2_CNN as ENC  # not done yet
+
+    elif args.encoder in ['Turbo_rate3_lte', 'Turbo_rate3_757']:
+        from encoders import ENC_TurboCode as ENC          # DeepTurbo, encoder not trainable.
+
+    elif args.encoder == 'rate3_cnn2d':
+        from encoders import ENC_CNN2D as ENC
+
+    else:
+        print('Unknown Encoder, stop')
+
+    return ENC
+
+def import_dec(args):
+
+    if args.decoder == 'TurboAE_rate2_rnn':
+        from decoders import DEC_LargeRNN_rate2 as DEC
+
+    elif args.decoder == 'TurboAE_rate2_cnn':
+        from decoders import DEC_LargeCNN_rate2 as DEC  # not done yet
+
+    elif args.decoder in ['TurboAE_rate3_cnn', 'TurboAE_rate3_cnn_dense']:
+        from decoders import DEC_LargeCNN as DEC
+
+    elif args.decoder == 'turboae_2int':
+        from decoders import DEC_LargeCNN2Int as DEC
+
+    elif args.decoder == 'rate3_cnn':
+        from decoders import CNN_decoder_rate3 as DEC
+
+    elif args.decoder in ['TurboAE_rate3_cnn2d', 'TurboAE_rate3_cnn2d_dense']:
+        from decoders import DEC_LargeCNN2D as DEC
+
+    elif args.decoder == 'TurboAE_rate3_rnn':
+        from decoders import DEC_LargeRNN as DEC
+
+    elif args.decoder == 'nbcjr_rate3':  # ICLR 2018 paper
+        from decoders import NeuralTurbofyDec as DEC
+
+    elif args.decoder == 'rate3_cnn2d':
+        from decoders import DEC_CNN2D as DEC
+
+    return DEC
+
+if __name__ == '__main__':
+    #################################################
+    # load args & setup logger
+    #################################################
+    identity = str(np.random.random())[2:8]
+    print('[ID]', identity)
+
+    # put all printed things to log file
+    logfile = open('./logs/'+identity+'_log.txt', 'a')
+    sys.stdout = Logger('./logs/'+identity+'_log.txt', sys.stdout)
+
+    args = get_args()
+    print(args)
+
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    #################################################
+    # Setup Channel AE: Encoder, Decoder, Channel
+    #################################################
+    # choose encoder and decoder.
+    ENC = import_enc(args)
+    DEC = import_dec(args)
+
+    # setup interleaver.
+    if args.is_interleave == 1:           # fixed interleaver.
+ seed = np.random.randint(0, 1) + rand_gen = mtrand.RandomState(seed) + p_array1 = rand_gen.permutation(arange(args.block_len)) + p_array2 = rand_gen.permutation(arange(args.block_len)) + + elif args.is_interleave == 0: + p_array1 = range(args.block_len) # no interleaver. + p_array2 = range(args.block_len) # no interleaver. + else: + seed = np.random.randint(0, args.is_interleave) + rand_gen = mtrand.RandomState(seed) + p_array1 = rand_gen.permutation(arange(args.block_len)) + seed = np.random.randint(0, args.is_interleave) + rand_gen = mtrand.RandomState(seed) + p_array2 = rand_gen.permutation(arange(args.block_len)) + + print('using random interleaver', p_array1, p_array2) + + if args.encoder == 'turboae_2int' and args.decoder == 'turboae_2int': + encoder = ENC(args, p_array1, p_array2) + decoder = DEC(args, p_array1, p_array2) + else: + encoder = ENC(args, p_array1) + decoder = DEC(args, p_array1) + + # modulation and demodulations. + from modulations import Modulation, DeModulation + + modulator = Modulation(args) + demodulator = DeModulation(args) + + # choose support channels + from channel_ae import Channel_ModAE + model = Channel_ModAE(args, encoder, decoder, modulator, demodulator).to(device) + + + + # make the model parallel + if args.is_parallel == 1: + model.enc.set_parallel() + model.dec.set_parallel() + + # weight loading + if args.init_nw_weight == 'default': + pass + + else: + pretrained_model = torch.load(args.init_nw_weight) + + try: + model.load_state_dict(pretrained_model.state_dict(), strict = False) + + except: + model.load_state_dict(pretrained_model, strict = False) + + model.args = args + + print(model) + + + ################################################################## + # Setup Optimizers, only Adam and Lookahead for now. + ################################################################## + + if args.optimizer == 'lookahead': + print('Using Lookahead Optimizers') + from optimizers import Lookahead + lookahead_k = 5 + lookahead_alpha = 0.5 + if args.num_train_enc != 0 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']: # no optimizer for encoder + enc_base_opt = optim.Adam(model.enc.parameters(), lr=args.enc_lr) + enc_optimizer = Lookahead(enc_base_opt, k=lookahead_k, alpha=lookahead_alpha) + + if args.num_train_dec != 0: + dec_base_opt = optim.Adam(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr) + dec_optimizer = Lookahead(dec_base_opt, k=lookahead_k, alpha=lookahead_alpha) + + general_base_opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr=args.dec_lr) + general_optimizer = Lookahead(general_base_opt, k=lookahead_k, alpha=lookahead_alpha) + + else: # Adam, SGD, etc.... 
+ if args.optimizer == 'adam': + OPT = optim.Adam + elif args.optimizer == 'sgd': + OPT = optim.SGD + else: + OPT = optim.Adam + + if args.num_train_enc != 0 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']: # no optimizer for encoder + enc_optimizer = OPT(model.enc.parameters(),lr=args.enc_lr) + + if args.num_train_dec != 0: + dec_optimizer = OPT(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr) + + if args.num_train_mod != 0: + mod_optimizer = OPT(filter(lambda p: p.requires_grad, model.mod.parameters()), lr=args.mod_lr) + + if args.num_train_demod != 0: + demod_optimizer = OPT(filter(lambda p: p.requires_grad, model.demod.parameters()), lr=args.demod_lr) + + general_optimizer = OPT(filter(lambda p: p.requires_grad, model.parameters()),lr=args.dec_lr) + + ################################################# + # Training Processes + ################################################# + report_loss, report_ber = [], [] + + for epoch in range(1, args.num_epoch + 1): + + if args.joint_train == 1 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']: + for idx in range(args.num_train_enc+args.num_train_dec): + train(epoch, model, general_optimizer, args, use_cuda = use_cuda, mode ='encoder') + + else: + if args.num_train_enc > 0 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']: + for idx in range(args.num_train_enc): + train(epoch, model, enc_optimizer, args, use_cuda = use_cuda, mode ='encoder') + + if args.num_train_dec > 0: + for idx in range(args.num_train_dec): + train(epoch, model, dec_optimizer, args, use_cuda = use_cuda, mode ='decoder') + + if args.num_train_mod > 0: + for idx in range(args.num_train_mod): + train(epoch, model, mod_optimizer, args, use_cuda = use_cuda, mode ='decoder') + + if args.num_train_demod > 0: + for idx in range(args.num_train_demod): + train(epoch, model, demod_optimizer, args, use_cuda = use_cuda, mode ='decoder') + + this_loss, this_ber = validate(model, general_optimizer, args, use_cuda = use_cuda) + report_loss.append(this_loss) + report_ber.append(this_ber) + + if args.print_test_traj == True: + print('test loss trajectory', report_loss) + print('test ber trajectory', report_ber) + print('total epoch', args.num_epoch) + + ################################################# + # Testing Processes + ################################################# + + torch.save(model.state_dict(), './tmp/torch_model_'+identity+'.pt') + print('saved model', './tmp/torch_model_'+identity+'.pt') + + if args.is_variable_block_len: + print('testing block length',args.block_len_low ) + test(model, args, block_len=args.block_len_low, use_cuda = use_cuda) + print('testing block length',args.block_len ) + test(model, args, block_len=args.block_len, use_cuda = use_cuda) + print('testing block length',args.block_len_high ) + test(model, args, block_len=args.block_len_high, use_cuda = use_cuda) + + else: + test(model, args, use_cuda = use_cuda) + + + + + + + + + + + + + + diff --git a/mod_trainer.py b/mod_trainer.py new file mode 100644 index 0000000..7b3316e --- /dev/null +++ b/mod_trainer.py @@ -0,0 +1,268 @@ +__author__ = 'yihanjiang' +import torch +import time +import torch.nn.functional as F + +eps = 1e-6 + +from utils import snr_sigma2db, snr_db2sigma, code_power, errors_ber_pos, errors_ber, errors_bler +from loss import customized_loss +from channels import generate_noise + +import numpy as np +from numpy import arange +from numpy.random import mtrand + 
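One detail worth flagging before the trainer functions below: noise is now drawn in symbol space rather than code space. The shape arithmetic, with illustrative default sizes (not values from this patch):

    batch_size, block_len, code_rate_n, mod_rate = 500, 100, 3, 2
    noise_shape = (batch_size, (block_len * code_rate_n) // mod_rate, mod_rate)
    print(noise_shape)      # (500, 150, 2)
    # The modulator emits (B, 150, 2) for (real, imag), so noise and symbols line up
    # only because mod_rate == 2 here; other mod_rate values would leave the last
    # dimensions mismatched, which is worth keeping in mind for this experimental code.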
+###################################################################################### +# +# Trainer, validation, and test for AE code design +# +###################################################################################### + + +def train(epoch, model, optimizer, args, use_cuda = False, verbose = True, mode = 'encoder'): + + device = torch.device("cuda" if use_cuda else "cpu") + + model.train() + start_time = time.time() + train_loss = 0.0 + k_same_code_counter = 0 + + + for batch_idx in range(int(args.num_block/args.batch_size)): + + + if args.is_variable_block_len: + block_len = np.random.randint(args.block_len_low, args.block_len_high) + else: + block_len = args.block_len + + optimizer.zero_grad() + + if args.is_k_same_code and mode == 'encoder': + if batch_idx == 0: + k_same_code_counter += 1 + X_train = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + elif k_same_code_counter == args.k_same_code: + k_same_code_counter = 1 + X_train = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + else: + k_same_code_counter += 1 + else: + X_train = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + + noise_shape = (args.batch_size, int(args.block_len * args.code_rate_n /args.mod_rate), args.mod_rate) + # train encoder/decoder with different SNR... seems to be a good practice. + if mode == 'encoder': + fwd_noise = generate_noise(noise_shape, args, snr_low=args.train_enc_channel_low, snr_high=args.train_enc_channel_high, mode = 'encoder') + else: + fwd_noise = generate_noise(noise_shape, args, snr_low=args.train_dec_channel_low, snr_high=args.train_dec_channel_high, mode = 'decoder') + + X_train, fwd_noise = X_train.to(device), fwd_noise.to(device) + + output, code = model(X_train, fwd_noise) + output = torch.clamp(output, 0.0, 1.0) + + if mode == 'encoder': + loss = customized_loss(output, X_train, args, noise=fwd_noise, code = code) + + else: + loss = customized_loss(output, X_train, args, noise=fwd_noise, code = code) + #loss = F.binary_cross_entropy(output, X_train) + + loss.backward() + train_loss += loss.item() + optimizer.step() + + end_time = time.time() + train_loss = train_loss /(args.num_block/args.batch_size) + if verbose: + print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss), \ + ' running time', str(end_time - start_time)) + + return train_loss + + + +def validate(model, optimizer, args, use_cuda = False, verbose = True): + + device = torch.device("cuda" if use_cuda else "cpu") + + model.eval() + test_bce_loss, test_custom_loss, test_ber= 0.0, 0.0, 0.0 + + with torch.no_grad(): + num_test_batch = int(args.num_block/args.batch_size * args.test_ratio) + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, args.block_len, args.code_rate_k), dtype=torch.float) + + noise_shape = (args.batch_size, int(args.block_len * args.code_rate_n /args.mod_rate), args.mod_rate) + + fwd_noise = generate_noise(noise_shape, args, + snr_low=args.train_enc_channel_low, + snr_high=args.train_enc_channel_low) + + X_test, fwd_noise= X_test.to(device), fwd_noise.to(device) + + optimizer.zero_grad() + output, codes = model(X_test, fwd_noise) + + output = torch.clamp(output, 0.0, 1.0) + + output = output.detach() + X_test = X_test.detach() + + test_bce_loss += F.binary_cross_entropy(output, X_test) + test_custom_loss += customized_loss(output, X_test, noise = fwd_noise, args = args, code = codes) + test_ber += errors_ber(output,X_test) 
+ + + test_bce_loss /= num_test_batch + test_custom_loss /= num_test_batch + test_ber /= num_test_batch + + if verbose: + print('====> Test set BCE loss', float(test_bce_loss), + 'Custom Loss',float(test_custom_loss), + 'with ber ', float(test_ber), + ) + + report_loss = float(test_bce_loss) + report_ber = float(test_ber) + + return report_loss, report_ber + + +def test(model, args, block_len = 'default',use_cuda = False): + + device = torch.device("cuda" if use_cuda else "cpu") + model.eval() + + if block_len == 'default': + block_len = args.block_len + else: + pass + + # Precomputes Norm Statistics. + if args.precompute_norm_stats: + with torch.no_grad(): + num_test_batch = int(args.num_block/(args.batch_size)* args.test_ratio) + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + X_test = X_test.to(device) + _ = model.enc(X_test) + print('Pre-computed norm statistics mean ',model.enc.mean_scalar, 'std ', model.enc.std_scalar) + + ber_res, bler_res = [], [] + ber_res_punc, bler_res_punc = [], [] + snr_interval = (args.snr_test_end - args.snr_test_start)* 1.0 / (args.snr_points-1) + snrs = [snr_interval* item + args.snr_test_start for item in range(args.snr_points)] + print('SNRS', snrs) + sigmas = snrs + + for sigma, this_snr in zip(sigmas, snrs): + test_ber, test_bler = .0, .0 + with torch.no_grad(): + num_test_batch = int(args.num_block/(args.batch_size)) + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + noise_shape = (args.batch_size, int(args.block_len * args.code_rate_n /args.mod_rate), args.mod_rate) + fwd_noise = generate_noise(noise_shape, args, test_sigma=sigma) + + X_test, fwd_noise= X_test.to(device), fwd_noise.to(device) + + X_hat_test, the_codes = model(X_test, fwd_noise) + + + test_ber += errors_ber(X_hat_test,X_test) + test_bler += errors_bler(X_hat_test,X_test) + + if batch_idx == 0: + test_pos_ber = errors_ber_pos(X_hat_test,X_test) + codes_power = code_power(the_codes) + else: + test_pos_ber += errors_ber_pos(X_hat_test,X_test) + codes_power += code_power(the_codes) + + if args.print_pos_power: + print('code power', codes_power/num_test_batch) + if args.print_pos_ber: + res_pos = test_pos_ber/num_test_batch + res_pos_arg = np.array(res_pos.cpu()).argsort()[::-1] + res_pos_arg = res_pos_arg.tolist() + print('positional ber', res_pos) + print('positional argmax',res_pos_arg) + try: + test_ber_punc, test_bler_punc = .0, .0 + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + noise_shape = (args.batch_size, int(args.block_len * args.code_rate_n /args.mod_rate), args.mod_rate) + fwd_noise = generate_noise(noise_shape, args, test_sigma=sigma) + X_test, fwd_noise= X_test.to(device), fwd_noise.to(device) + + X_hat_test, the_codes = model(X_test, fwd_noise) + + test_ber_punc += errors_ber(X_hat_test,X_test, positions = res_pos_arg[:args.num_ber_puncture]) + test_bler_punc += errors_bler(X_hat_test,X_test, positions = res_pos_arg[:args.num_ber_puncture]) + + if batch_idx == 0: + test_pos_ber = errors_ber_pos(X_hat_test,X_test) + codes_power = code_power(the_codes) + else: + test_pos_ber += errors_ber_pos(X_hat_test,X_test) + codes_power += code_power(the_codes) + except: + print('no pos BER specified.') + + test_ber /= num_test_batch + test_bler /= num_test_batch + print('Test SNR',this_snr ,'with ber ', 
float(test_ber), 'with bler', float(test_bler))
+        ber_res.append(float(test_ber))
+        bler_res.append( float(test_bler))
+
+        try:
+            test_ber_punc  /= num_test_batch
+            test_bler_punc /= num_test_batch
+            print('Punctured Test SNR',this_snr ,'with ber ', float(test_ber_punc), 'with bler', float(test_bler_punc))
+            ber_res_punc.append(float(test_ber_punc))
+            bler_res_punc.append( float(test_bler_punc))
+        except:
+            print('No puncturing was run.')
+
+    print('final results on SNRs ', snrs)
+    print('BER', ber_res)
+    print('BLER', bler_res)
+    print('final results on punctured SNRs ', snrs)
+    print('BER', ber_res_punc)
+    print('BLER', bler_res_punc)
+
+    # compute adjusted SNR. (some quantization might make power != 1.0)
+    enc_power = 0.0
+    with torch.no_grad():
+        for idx in range(num_test_batch):
+            X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float)
+            X_test = X_test.to(device)
+            X_code = model.enc(X_test)
+            enc_power += torch.std(X_code)
+    enc_power /= float(num_test_batch)
+    print('encoder power is', enc_power)
+    adj_snrs = [snr_sigma2db(snr_db2sigma(item)/enc_power) for item in snrs]
+    print('adjusted SNR should be', adj_snrs)
+
+
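The adjusted-SNR correction at the end of test() is a pure power rescaling. A worked example of the arithmetic, assuming utils.py's sigma = 10**(-snr/20) convention for snr_db2sigma (the encoder power value is hypothetical):

    import math

    def snr_db2sigma(snr):                 # assumed to match utils.py
        return 10 ** (-snr / 20.0)

    def snr_sigma2db(sigma):
        return -20.0 * math.log10(sigma)

    snr, enc_power = 1.0, 1.1              # measured enc_power > 1 means extra transmit power
    adj = snr_sigma2db(snr_db2sigma(snr) / enc_power)
    print(round(adj, 3))                   # 1.828, i.e. snr + 20*log10(enc_power)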
diff --git a/modulation.py b/modulation.py
deleted file mode 100644
index 35bcb30..0000000
--- a/modulation.py
+++ /dev/null
@@ -1,34 +0,0 @@
-__author__ = 'yihanjiang'
-
-'''
-This requires to upgrade real value channels to complex channels, which requires quite a lot of engineering.
-For now just use BPSK to make things easier.
-BPSK is a over-loaded idea: means just real-value channel.
-
-Modulation means each symbol has limited power.
-Modulati
-Input: continuous value
-'''
-
-def modulation(input_signal, mod_mode = 'bpsk'):
-    if mod_mode == 'bpsk':
-        return input_signal
-
-    elif mod_mode == 'continuous_complex':
-        # build block_len * code_rate / 2 symbols.
-        pass
-
-    elif mod_mode == 'qpsk':
-        pass
-
-    elif mod_mode == 'qam':
-        pass
-
-
-
-def demod(rec_signal, demod_mode = 'bpsk'):
-    if demod_mode == 'bpsk':
-        return rec_signal
-
-
-
diff --git a/modulations.py b/modulations.py
new file mode 100644
index 0000000..5062775
--- /dev/null
+++ b/modulations.py
@@ -0,0 +1,110 @@
+__author__ = 'yihanjiang'
+import torch
+import torch.nn.functional as F
+
+from cnn_utils import SameShapeConv1d
+
+##############################################
+# STE implementation
+##############################################
+
+class STEQuantize(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inputs):
+
+        enc_value_limit    = 1.0
+        enc_quantize_level = 2.0
+
+        ctx.save_for_backward(inputs)
+        ctx.enc_value_limit    = enc_value_limit
+        ctx.enc_quantize_level = enc_quantize_level
+
+        x_lim_abs    = enc_value_limit
+        x_lim_range  = 2.0 * x_lim_abs
+        x_input_norm = torch.clamp(inputs, -x_lim_abs, x_lim_abs)
+
+        if enc_quantize_level == 2:
+            outputs_int = torch.sign(x_input_norm)
+        else:
+            outputs_int = torch.round((x_input_norm + x_lim_abs) * ((enc_quantize_level - 1.0)/x_lim_range)) * x_lim_range/(enc_quantize_level - 1.0) - x_lim_abs
+
+        return outputs_int
+
+    @staticmethod
+    def backward(ctx, grad_output):
+
+        input, = ctx.saved_tensors
+        # straight-through estimator: copy the incoming gradient, but zero it
+        # outside the clipping range.
+        grad_input = grad_output.clone()
+        grad_input[input > ctx.enc_value_limit]  = 0
+        grad_input[input < -ctx.enc_value_limit] = 0
+
+        # forward() takes a single tensor input, so return a single gradient.
+        return grad_input
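The straight-through quantizer above is easiest to see on a concrete tensor; a minimal sketch of the sign-quantized forward pass and the clipped pass-through gradient:

    import torch
    from modulations import STEQuantize   # the class defined above

    x = torch.tensor([-1.7, -0.3, 0.4, 2.0], requires_grad=True)
    y = STEQuantize.apply(x)
    print(y)                              # tensor([-1., -1.,  1.,  1.])
    y.sum().backward()
    print(x.grad)                         # tensor([0., 1., 1., 0.]): the gradient passes
                                          # through inside [-1, 1] and is zeroed outside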
+
+
+class Modulation(torch.nn.Module):
+    def __init__(self, args):
+        super(Modulation, self).__init__()
+
+        use_cuda = not args.no_cuda and torch.cuda.is_available()
+        self.this_device = torch.device("cuda" if use_cuda else "cpu")
+        self.args = args
+
+        self.mod_layer = SameShapeConv1d(num_layer=args.mod_num_layer, in_channels=args.mod_rate,
+                                         out_channels= args.mod_num_unit, kernel_size = 1, no_act = False)
+        self.mod_final = SameShapeConv1d(num_layer=1, in_channels=args.mod_num_unit,
+                                         out_channels= 2, kernel_size = 1, no_act = True)
+
+
+    def forward(self, inputs):
+        # Input has shape (B, L, R).
+        # Output has shape (B, L*R/mod_rate, 2); the last dimension is (real, imag).
+
+        inputs_flatten = inputs.view(self.args.batch_size, int(self.args.block_len * self.args.code_rate_n / self.args.mod_rate), self.args.mod_rate)
+        mod_symbols = self.mod_final(self.mod_layer(inputs_flatten))
+
+        if self.args.mod_pc == 'qpsk':
+            # normalize, then hard-quantize with the STE so symbols land on a grid.
+            this_mean = torch.mean(mod_symbols)
+            this_std  = torch.std(mod_symbols)
+            mod_symbols = (mod_symbols - this_mean)/this_std
+            stequantize = STEQuantize.apply
+            outputs = stequantize(mod_symbols)
+        elif self.args.mod_pc == 'symbol_power':
+            # normalize each symbol position to unit power across the batch.
+            this_mean = torch.mean(torch.mean(mod_symbols, dim=2), dim=0)
+            new_symbol = mod_symbols.permute(0,2,1)
+            new_symbol_shape = new_symbol.shape
+            this_std  = torch.std(new_symbol.view(new_symbol_shape[0]*new_symbol_shape[1], new_symbol_shape[2]), dim=0)
+
+            this_mean = this_mean.unsqueeze(0).unsqueeze(2)
+            this_std  = this_std.unsqueeze(0).unsqueeze(2)
+            outputs = (mod_symbols - this_mean)/this_std
+
+        elif self.args.mod_pc == 'block_power':
+            # normalize the whole block to zero mean, unit power.
+            this_mean = torch.mean(mod_symbols)
+            this_std  = torch.std(mod_symbols)
+            outputs = (mod_symbols - this_mean)/this_std
+
+        return outputs
+
+
+class DeModulation(torch.nn.Module):
+    def __init__(self, args):
+        super(DeModulation, self).__init__()
+
+        use_cuda = not args.no_cuda and torch.cuda.is_available()
+        self.this_device = torch.device("cuda" if use_cuda else "cpu")
+        self.args = args
+
+        self.demod_layer = SameShapeConv1d(num_layer=args.demod_num_layer, in_channels=2,
+                                           out_channels= self.args.demod_num_unit, kernel_size = 1)
+        self.demod_final = SameShapeConv1d(num_layer=1, in_channels=args.demod_num_unit,
+                                           out_channels= args.mod_rate, kernel_size = 1, no_act = True)
+
+    def forward(self, inputs):
+        # Input has shape (B, L*R/mod_rate, 2); the last dimension is (real, imag).
+        # Output has shape (B, L, R), back in code space.
+        demod_symbols = self.demod_final(self.demod_layer(inputs))
+        demod_codes = demod_symbols.reshape(self.args.batch_size, self.args.block_len, self.args.code_rate_n)
+
+        return demod_codes
\ No newline at end of file
diff --git a/trainer.py b/trainer.py
index 3655fd2..b163064 100644
--- a/trainer.py
+++ b/trainer.py
@@ -52,12 +52,12 @@ def train(epoch, model, optimizer, args, use_cuda = False, verbose = True, mode
         else:
             X_train = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float)
 
-
+        noise_shape = (args.batch_size, args.block_len, args.code_rate_n)
         # train encoder/decoder with different SNR... seems to be a good practice.
if mode == 'encoder': - fwd_noise = generate_noise(X_train.shape, args, snr_low=args.train_enc_channel_low, snr_high=args.train_enc_channel_high, mode = 'encoder') + fwd_noise = generate_noise(noise_shape, args, snr_low=args.train_enc_channel_low, snr_high=args.train_enc_channel_high, mode = 'encoder') else: - fwd_noise = generate_noise(X_train.shape, args, snr_low=args.train_dec_channel_low, snr_high=args.train_dec_channel_high, mode = 'decoder') + fwd_noise = generate_noise(noise_shape, args, snr_low=args.train_dec_channel_low, snr_high=args.train_dec_channel_high, mode = 'decoder') X_train, fwd_noise = X_train.to(device), fwd_noise.to(device) @@ -96,7 +96,8 @@ def validate(model, optimizer, args, use_cuda = False, verbose = True): num_test_batch = int(args.num_block/args.batch_size * args.test_ratio) for batch_idx in range(num_test_batch): X_test = torch.randint(0, 2, (args.batch_size, args.block_len, args.code_rate_k), dtype=torch.float) - fwd_noise = generate_noise(X_test.shape, args, + noise_shape = (args.batch_size, args.block_len, args.code_rate_n) + fwd_noise = generate_noise(noise_shape, args, snr_low=args.train_enc_channel_low, snr_high=args.train_enc_channel_low) @@ -143,14 +144,16 @@ def test(model, args, block_len = 'default',use_cuda = False): # Precomputes Norm Statistics. if args.precompute_norm_stats: - num_test_batch = int(args.num_block/(args.batch_size)* args.test_ratio) - for batch_idx in range(num_test_batch): - X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) - X_test = X_test.to(device) - _ = model.enc(X_test) - print('Pre-computed norm statistics mean ',model.enc.mean_scalar, 'std ', model.enc.std_scalar) + with torch.no_grad(): + num_test_batch = int(args.num_block/(args.batch_size)* args.test_ratio) + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + X_test = X_test.to(device) + _ = model.enc(X_test) + print('Pre-computed norm statistics mean ',model.enc.mean_scalar, 'std ', model.enc.std_scalar) ber_res, bler_res = [], [] + ber_res_punc, bler_res_punc = [], [] snr_interval = (args.snr_test_end - args.snr_test_start)* 1.0 / (args.snr_points-1) snrs = [snr_interval* item + args.snr_test_start for item in range(args.snr_points)] print('SNRS', snrs) @@ -159,7 +162,7 @@ def test(model, args, block_len = 'default',use_cuda = False): for sigma, this_snr in zip(sigmas, snrs): test_ber, test_bler = .0, .0 with torch.no_grad(): - num_test_batch = int(args.num_block/(args.batch_size)* args.test_ratio) + num_test_batch = int(args.num_block/(args.batch_size)) for batch_idx in range(num_test_batch): X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) fwd_noise = generate_noise(X_test.shape, args, test_sigma=sigma) @@ -182,7 +185,31 @@ def test(model, args, block_len = 'default',use_cuda = False): if args.print_pos_power: print('code power', codes_power/num_test_batch) if args.print_pos_ber: - print('positional ber', test_pos_ber/num_test_batch) + res_pos = test_pos_ber/num_test_batch + res_pos_arg = np.array(res_pos.cpu()).argsort()[::-1] + res_pos_arg = res_pos_arg.tolist() + print('positional ber', res_pos) + print('positional argmax',res_pos_arg) + try: + test_ber_punc, test_bler_punc = .0, .0 + for batch_idx in range(num_test_batch): + X_test = torch.randint(0, 2, (args.batch_size, block_len, args.code_rate_k), dtype=torch.float) + fwd_noise = generate_noise(X_test.shape, args, 
test_sigma=sigma)
+                    X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)
+
+                    X_hat_test, the_codes = model(X_test, fwd_noise)
+
+                    test_ber_punc  += errors_ber(X_hat_test,X_test, positions = res_pos_arg[:args.num_ber_puncture])
+                    test_bler_punc += errors_bler(X_hat_test,X_test, positions = res_pos_arg[:args.num_ber_puncture])
+
+                    if batch_idx == 0:
+                        test_pos_ber = errors_ber_pos(X_hat_test,X_test)
+                        codes_power  = code_power(the_codes)
+                    else:
+                        test_pos_ber += errors_ber_pos(X_hat_test,X_test)
+                        codes_power  += code_power(the_codes)
+            except:
+                print('no pos BER specified.')
 
         test_ber  /= num_test_batch
         test_bler /= num_test_batch
@@ -190,9 +217,21 @@ def test(model, args, block_len = 'default',use_cuda = False):
         ber_res.append(float(test_ber))
         bler_res.append( float(test_bler))
 
+        try:
+            test_ber_punc  /= num_test_batch
+            test_bler_punc /= num_test_batch
+            print('Punctured Test SNR',this_snr ,'with ber ', float(test_ber_punc), 'with bler', float(test_bler_punc))
+            ber_res_punc.append(float(test_ber_punc))
+            bler_res_punc.append( float(test_bler_punc))
+        except:
+            print('No puncturing was run.')
+
     print('final results on SNRs ', snrs)
     print('BER', ber_res)
     print('BLER', bler_res)
+    print('final results on punctured SNRs ', snrs)
+    print('BER', ber_res_punc)
+    print('BLER', bler_res_punc)
 
     # compute adjusted SNR. (some quantization might make power!=1.0)
     enc_power = 0.0
diff --git a/utils.py b/utils.py
index 4730aed..c3821ee 100644
--- a/utils.py
+++ b/utils.py
@@ -3,12 +3,18 @@
 import numpy as np
 import math
 
-def errors_ber(y_true, y_pred):
+def errors_ber(y_true, y_pred, positions = 'default'):
     y_true = y_true.view(y_true.shape[0], -1, 1)
     y_pred = y_pred.view(y_pred.shape[0], -1, 1)
 
     myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float()
-    res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1])
+    if positions == 'default':
+        res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1])
+    else:
+        # punctured positions contribute zero errors, but still count in the
+        # averaging denominator.
+        res = torch.mean(myOtherTensor, dim=0).type(torch.FloatTensor)
+        for pos in positions:
+            res[pos] = 0.0
+        res = torch.mean(res)
     return res
 
 def errors_ber_list(y_true, y_pred):
@@ -22,7 +28,7 @@
 
     return res_list_tensor
 
-def errors_ber_pos(y_true, y_pred):
+def errors_ber_pos(y_true, y_pred, discard_pos = []):
     y_true = y_true.view(y_true.shape[0], -1, 1)
     y_pred = y_pred.view(y_pred.shape[0], -1, 1)
@@ -40,7 +46,8 @@ def code_power(the_codes):
         res = tmp
     return res
 
-def errors_bler(y_true, y_pred):
+def errors_bler(y_true, y_pred, positions = 'default'):
+
     y_true = y_true.view(y_true.shape[0], -1, 1)
     y_pred = y_pred.view(y_pred.shape[0], -1, 1)
 
@@ -48,7 +55,14 @@
     decoded_bits = torch.round(y_pred)
     X_test       = torch.round(y_true)
     tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]])
     tp0 = tp0.cpu().numpy()
-    bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0])
+
+    if positions == 'default':
+        bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0])
+    else:
+        # ignore errors at punctured positions before declaring a block error.
+        for pos in positions:
+            tp0[:, pos] = 0.0
+        bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0])
+
     return bler_err_rate
 
 # note there are a few definitions of SNR. In our result, we stick to the following SNR setup.
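To make the new positions argument concrete, a small usage sketch of punctured BER/BLER with toy tensors (uses the helpers above):

    import torch
    from utils import errors_ber, errors_bler

    y_true = torch.tensor([[1., 0., 1., 1.],
                           [0., 0., 1., 0.]])
    y_pred = torch.tensor([[1., 1., 1., 1.],    # one bit error, at position 1
                           [0., 0., 0., 0.]])   # one bit error, at position 2
    print(float(errors_ber(y_true, y_pred)))                    # 0.25: 2 errors / 8 bits
    print(float(errors_ber(y_true, y_pred, positions=[1, 2])))  # 0.0: both errors punctured
    print(float(errors_bler(y_true, y_pred, positions=[1, 2]))) # 0.0: no block errors left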