diff --git a/encoders.py b/encoders.py
index c2593c4..cebe48d 100644
--- a/encoders.py
+++ b/encoders.py
@@ -126,6 +126,8 @@ def power_constraint(self, x_input):
             stequantize = STEQuantize.apply
             x_input_norm = stequantize(x_input_norm, self.args)
 
+        if self.args.enc_truncate_limit > 0:
+            x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit)
 
         return x_input_norm
@@ -531,17 +533,20 @@ def __init__(self, args, p_array):
         self.enc_cnn_1 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)
 
-        self.enc_linear_1 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+        self.enc_linear_1 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                                  out_channels=1, kernel_size=1, no_act=True)
 
         self.enc_cnn_2 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)
 
-        self.enc_linear_2 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+        self.enc_linear_2 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                                  out_channels=1, kernel_size=1, no_act=True)
 
         self.enc_cnn_3 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)
 
-        self.enc_linear_3 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+        self.enc_linear_3 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                                  out_channels=1, kernel_size=1, no_act=True)
 
         self.interleaver = Interleaver2D(args, p_array)
@@ -563,15 +568,15 @@ def forward(self, inputs):
         inputs = 2.0*inputs - 1.0
 
         x_sys = self.enc_cnn_1(inputs)
-        x_sys = self.enc_act(self.enc_linear_1(x_sys))
+        x_sys = self.enc_linear_1(x_sys)
 
         x_p1 = self.enc_cnn_2(inputs)
-        x_p1 = self.enc_act(self.enc_linear_2(x_p1))
+        x_p1 = self.enc_linear_2(x_p1)
 
         x_sys_int = self.interleaver(inputs)
         x_p2 = self.enc_cnn_3(x_sys_int)
-        x_p2 = self.enc_act(self.enc_linear_3(x_p2))
+        x_p2 = self.enc_linear_3(x_p2)
 
         x_tx = torch.cat([x_sys, x_p1, x_p2], dim=1)
         x_tx = x_tx.view(self.args.batch_size, self.args.code_rate_n, self.args.block_len)
diff --git a/get_args.py b/get_args.py
index 00019e9..b3deddd 100644
--- a/get_args.py
+++ b/get_args.py
@@ -131,6 +131,7 @@ def get_args():
     parser.add_argument('-train_channel_mode',
                         choices=['block_norm','block_norm_ste'],
                         default='block_norm')
+    parser.add_argument('-enc_truncate_limit', type=float, default=0, help='0 means no truncation')
 
     ################################################################
@@ -154,7 +155,7 @@ def get_args():
     ################################################################
     # Loss related parameters
     ################################################################
-    parser.add_argument('-loss', choices=['bce', 'mse','focal', 'bce_block', 'maxBCE', 'bce_rl', 'enc_rl', 'soft_ber'],
+    parser.add_argument('-loss', choices=['bce', 'mse','focal', 'bce_block', 'maxBCE', 'bce_rl', 'enc_rl', 'soft_ber', 'sortBCE'],
                         default='bce', help='only BCE works')
 
     parser.add_argument('-ber_lambda', type = float, default=1.0, help = 'default 0.0, the more emphasis on BER loss, only for bce_rl')
diff --git a/interleavers.py b/interleavers.py
index 7ce3340..fe640f2 100644
--- a/interleavers.py
+++ b/interleavers.py
@@ -49,6 +49,64 @@ def forward(self, inputs):
 
 # TBD: change 2D interleavers
 # 2D interleavers seems not working well... Don't know why...
+class Interleaver2Dold(torch.nn.Module):
+    def __init__(self, args, p_array):
+        super(Interleaver2Dold, self).__init__()
+        self.args = args
+        self.p_array = torch.LongTensor(p_array).view(len(p_array))  #.view(args.img_size, args.img_size)
+
+    def set_parray(self, p_array):
+        self.p_array = torch.LongTensor(p_array).view(len(p_array))  #.view(self.args.img_size, args.img_size)
+
+    def forward(self, inputs):
+        input_shape = inputs.shape
+
+        # flatten the two spatial dims, move positions to dim 0, then permute via p_array
+        inputs = inputs.view(input_shape[0], input_shape[1], input_shape[2]*input_shape[3])
+        inputs = inputs.permute(2, 0, 1)
+        res = inputs[self.p_array]
+
+        res = res.permute(1, 2, 0)
+        res = res.view(input_shape)
+
+        return res
+
+class DeInterleaver2Dold(torch.nn.Module):
+    def __init__(self, args, p_array):
+        super(DeInterleaver2Dold, self).__init__()
+        self.args = args
+
+        # build the inverse permutation of p_array
+        self.reverse_p_array = [0 for _ in range(len(p_array))]
+        for idx in range(len(p_array)):
+            self.reverse_p_array[p_array[idx]] = idx
+
+        self.reverse_p_array = torch.LongTensor(self.reverse_p_array).view(self.args.img_size**2)
+
+    def set_parray(self, p_array):
+        self.reverse_p_array = [0 for _ in range(len(p_array))]
+        for idx in range(len(p_array)):
+            self.reverse_p_array[p_array[idx]] = idx
+
+        self.reverse_p_array = torch.LongTensor(self.reverse_p_array).view(self.args.img_size**2)
+
+    def forward(self, inputs):
+        input_shape = inputs.shape
+
+        inputs = inputs.view(input_shape[0], input_shape[1], input_shape[2]*input_shape[3])
+        inputs = inputs.permute(2, 0, 1)
+        res = inputs[self.reverse_p_array]
+
+        res = res.permute(1, 2, 0)
+        res = res.view(input_shape)
+
+        return res
+
+
+# TBD: change 2D interleavers
+# Play with real 2D interleavers: p_array with 2-step interleaving.
 class Interleaver2D(torch.nn.Module):
     def __init__(self, args, p_array):
         super(Interleaver2D, self).__init__()
diff --git a/loss.py b/loss.py
index c976f5e..4a8e451 100644
--- a/loss.py
+++ b/loss.py
@@ -86,10 +86,25 @@ def customized_loss(output, X_train, args, size_average = True, noise = None, code = None):
         BCE_loss_tmp = F.binary_cross_entropy(output, X_train, reduce=False)
         bce_loss = torch.mean(BCE_loss_tmp)
 
-        tmp, _ = torch.max(BCE_loss_tmp, dim=1, keepdim=False)
+        pos_loss = torch.mean(BCE_loss_tmp, dim=0)
+
+        tmp, _ = torch.max(pos_loss, dim=0)
         max_loss = torch.mean(tmp)
 
         loss = bce_loss + args.lambda_maxBCE * max_loss
 
+    elif args.loss == 'sortBCE':
+        output = torch.clamp(output, 0.0, 1.0)
+        BCE_loss_tmp = F.binary_cross_entropy(output, X_train, reduce=False)
+
+        bce_loss = torch.mean(BCE_loss_tmp)
+        pos_loss = torch.mean(BCE_loss_tmp, dim=0)
+
+        # sort per-position BCE (flattened), largest first; penalize the 5 worst positions
+        tmp, _ = torch.sort(pos_loss.view(-1), descending=True)
+        sort_loss = torch.sum(tmp[:5])
+
+        loss = bce_loss + args.lambda_maxBCE * sort_loss
+
     return loss
diff --git a/modulation.py b/modulation.py
index 64c65c1..35bcb30 100644
--- a/modulation.py
+++ b/modulation.py
@@ -1 +1,30 @@
 __author__ = 'yihanjiang'
+
+'''
+True modulation requires upgrading the real-valued channels to complex-valued
+channels, which takes quite a lot of engineering. For now just use BPSK to keep
+things simple. "BPSK" is an overloaded idea here: it simply means a real-valued
+channel.
+
+Modulation means each symbol has limited power.
+Input: continuous values.
+'''
+
+def modulation(input_signal, mod_mode = 'bpsk'):
+    if mod_mode == 'bpsk':
+        return input_signal
+
+    elif mod_mode == 'continuous_complex':
+        # build block_len * code_rate / 2 symbols.
+        pass
+
+    elif mod_mode == 'qpsk':
+        pass
+
+    elif mod_mode == 'qam':
+        pass
+
+
+def demod(rec_signal, demod_mode = 'bpsk'):
+    if demod_mode == 'bpsk':
+        return rec_signal
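
Notes on the patch, with small self-contained Python sketches. Shapes, names, and values in the sketches are illustrative, not taken from the repo.

The new -enc_truncate_limit flag clamps the power-normalized encoder output to [-limit, +limit]; 0 keeps the old behavior. A minimal sketch of what the encoders.py hunk adds, assuming the usual mean-0/std-1 block normalization done in power_constraint:

    import torch

    enc_truncate_limit = 1.5           # illustrative value; 0 disables truncation
    x = torch.randn(4, 1, 10, 10)      # stands in for the raw encoder output

    # block normalization, then the optional clamp from this patch
    x_norm = (x - x.mean()) / x.std()
    if enc_truncate_limit > 0:
        x_norm = torch.clamp(x_norm, -enc_truncate_limit, enc_truncate_limit)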
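
The sortBCE loss adds the sum of the k hardest per-position BCE terms to the mean BCE (k is hard-coded to 5 in the patch). A standalone sketch with toy shapes:

    import torch
    import torch.nn.functional as F

    output  = torch.rand(4, 10, 1)                       # decoder outputs in [0, 1]
    X_train = torch.randint(0, 2, (4, 10, 1)).float()    # ground-truth bits
    lambda_maxBCE = 0.01                                 # illustrative weight

    bce = F.binary_cross_entropy(output, X_train, reduction='none')
    bce_loss = bce.mean()
    pos_loss = bce.mean(dim=0)                 # average loss per bit position

    worst, _ = torch.sort(pos_loss.view(-1), descending=True)
    sort_loss = worst[:5].sum()                # the 5 hardest positions

    loss = bce_loss + lambda_maxBCE * sort_loss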
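
The 2D interleavers flatten the two spatial dims, permute positions with p_array, and restore the shape; the de-interleaver applies the inverse permutation built in its __init__. A round-trip check with illustrative shapes:

    import torch

    B, C, H, W = 2, 3, 4, 4
    x = torch.randn(B, C, H, W)
    p_array = torch.randperm(H * W)            # illustrative interleaving pattern

    # inverse permutation, as DeInterleaver2Dold builds it
    reverse_p = torch.empty_like(p_array)
    reverse_p[p_array] = torch.arange(H * W)

    flat        = x.view(B, C, H * W).permute(2, 0, 1)   # positions first
    interleaved = flat[p_array]
    restored    = interleaved[reverse_p].permute(1, 2, 0).view(B, C, H, W)

    assert torch.equal(restored, x)            # de-interleaving inverts interleaving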
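
The continuous_complex branch in modulation.py is left as a stub. One plausible reading of its "block_len * code_rate / 2 symbols" comment, as a hedged sketch (the function names and pairing scheme are mine, not the module's):

    import torch

    def continuous_complex_mod(x):
        # pair consecutive real samples into complex symbols:
        # (..., n) reals -> (..., n/2) complex symbols
        assert x.shape[-1] % 2 == 0
        return torch.complex(x[..., 0::2], x[..., 1::2])

    def continuous_complex_demod(s):
        # inverse: interleave real and imaginary parts back into reals
        return torch.stack((s.real, s.imag), dim=-1).flatten(-2)

    x = torch.randn(4, 8)
    assert torch.equal(continuous_complex_demod(continuous_complex_mod(x)), x)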