add some CCE and CDec
yihanjiang committed Dec 29, 2019
1 parent 2225d10 commit 8e108e7
Showing 6 changed files with 128 additions and 8 deletions.
17 changes: 11 additions & 6 deletions encoders.py
@@ -126,6 +126,8 @@ def power_constraint(self, x_input):
stequantize = STEQuantize.apply
x_input_norm = stequantize(x_input_norm, self.args)

+if self.args.enc_truncate_limit > 0:
+    x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit)

return x_input_norm

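STEQuantize itself is not part of this diff. As context, a minimal straight-through quantizer consistent with the call site above might look like the sketch below (the sign() forward and the unused args input are assumptions, not the repo's actual implementation):

import torch

class STEQuantizeSketch(torch.autograd.Function):
    # Hypothetical stand-in for the repo's STEQuantize.
    @staticmethod
    def forward(ctx, x, args):
        # Hard-quantize the normalized codeword to {-1, +1}.
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged;
        # the second input (args) gets no gradient.
        return grad_output, None

stequantize = STEQuantizeSketch.apply
x_q = stequantize(torch.randn(4, 8), None)

The new enc_truncate_limit branch then additionally clamps x_input_norm to the symmetric interval [-limit, +limit] before it is returned.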
@@ -531,17 +533,20 @@ def __init__(self, args, p_array):
self.enc_cnn_1 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                       out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)

-self.enc_linear_1 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+self.enc_linear_1 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                          out_channels=1, kernel_size=1, no_act=True)

self.enc_cnn_2 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                       out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)

-self.enc_linear_2 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+self.enc_linear_2 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                          out_channels=1, kernel_size=1, no_act=True)

self.enc_cnn_3 = CNN2d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                       out_channels=args.enc_num_unit, kernel_size=args.enc_kernel_size)

-self.enc_linear_3 = torch.nn.Conv2d(args.enc_num_unit, 1, 1, 1, 0, bias=True)
+self.enc_linear_3 = CNN2d(num_layer=1, in_channels=args.enc_num_unit,
+                          out_channels=1, kernel_size=1, no_act=True)

self.interleaver = Interleaver2D(args, p_array)

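CNN2d is defined elsewhere in the repo, so the calls above only pin down its interface. A rough sketch of a module matching that interface (the ELU activation, padding scheme, and layer layout are guesses):

import torch

class CNN2dSketch(torch.nn.Module):
    # Hypothetical approximation of CNN2d: a stack of shape-preserving
    # Conv2d layers, optionally without activations (no_act=True).
    def __init__(self, num_layer, in_channels, out_channels, kernel_size, no_act=False):
        super(CNN2dSketch, self).__init__()
        layers = []
        for i in range(num_layer):
            layers.append(torch.nn.Conv2d(in_channels if i == 0 else out_channels,
                                          out_channels, kernel_size,
                                          padding=kernel_size // 2))
            if not no_act:
                layers.append(torch.nn.ELU())
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

Under this reading, CNN2d(num_layer=1, ..., kernel_size=1, no_act=True) is still a plain 1x1 linear map, just routed through the shared wrapper; note that forward() below also drops the enc_act calls, so each branch now ends without a nonlinearity.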
@@ -563,15 +568,15 @@ def forward(self, inputs):

inputs = 2.0 * inputs - 1.0
x_sys = self.enc_cnn_1(inputs)
-x_sys = self.enc_act(self.enc_linear_1(x_sys))
+x_sys = self.enc_linear_1(x_sys)

x_p1 = self.enc_cnn_2(inputs)
-x_p1 = self.enc_act(self.enc_linear_2(x_p1))
+x_p1 = self.enc_linear_2(x_p1)

x_sys_int = self.interleaver(inputs)

x_p2 = self.enc_cnn_3(x_sys_int)
-x_p2 = self.enc_act(self.enc_linear_3(x_p2))
+x_p2 = self.enc_linear_3(x_p2)

x_tx = torch.cat([x_sys, x_p1, x_p2], dim=1)
x_tx = x_tx.view(self.args.batch_size, self.args.code_rate_n, self.args.block_len)
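For the reshape at the end, a quick shape sketch, under the assumption that the encoder works on (batch, code_rate_k, S, S) image-shaped blocks with block_len = S*S and code_rate_n = 3:

# x_sys, x_p1, x_p2                       : (batch, 1, S, S), one map per branch
# torch.cat([x_sys, x_p1, x_p2], dim=1)   : (batch, 3, S, S)
# x_tx.view(batch, code_rate_n, block_len): (batch, 3, S*S)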
3 changes: 2 additions & 1 deletion get_args.py
@@ -131,6 +131,7 @@ def get_args():
parser.add_argument('-train_channel_mode',
                    choices=['block_norm', 'block_norm_ste'],
                    default='block_norm')
+parser.add_argument('-enc_truncate_limit', type=float, default=0, help='0 means no truncation')


################################################################
@@ -154,7 +155,7 @@ def get_args():
# Loss related parameters
################################################################

-parser.add_argument('-loss', choices=['bce', 'mse', 'focal', 'bce_block', 'maxBCE', 'bce_rl', 'enc_rl', 'soft_ber'],
+parser.add_argument('-loss', choices=['bce', 'mse', 'focal', 'bce_block', 'maxBCE', 'bce_rl', 'enc_rl', 'soft_ber', 'sortBCE'],
                     default='bce', help='only BCE works')

parser.add_argument('-ber_lambda', type=float, default=1.0, help='default 1.0, the more emphasis on BER loss, only for bce_rl')
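A hypothetical invocation exercising the two new options together (the script name and flag values are illustrative, not from this commit):

python main.py -loss sortBCE -enc_truncate_limit 1.5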
58 changes: 58 additions & 0 deletions interleavers.py
@@ -49,6 +49,64 @@ def forward(self, inputs):

# TBD: change 2D interleavers
# 2D interleavers seem not to work well... Don't know why...
+class Interleaver2Dold(torch.nn.Module):
+    def __init__(self, args, p_array):
+        super(Interleaver2Dold, self).__init__()
+        self.args = args
+        self.p_array = torch.LongTensor(p_array).view(len(p_array))  # .view(args.img_size, args.img_size)
+
+    def set_parray(self, p_array):
+        self.p_array = torch.LongTensor(p_array).view(len(p_array))  # .view(self.args.img_size, args.img_size)
+
+    def forward(self, inputs):
+        input_shape = inputs.shape
+
+        inputs = inputs.view(input_shape[0], input_shape[1], input_shape[2] * input_shape[3])
+        inputs = inputs.permute(2, 0, 1)
+        res = inputs[self.p_array]
+
+        res = res.permute(1, 2, 0)
+        res = res.reshape(input_shape)
+
+        return res
+
+class DeInterleaver2Dold(torch.nn.Module):
+    def __init__(self, args, p_array):
+        super(DeInterleaver2Dold, self).__init__()
+        self.args = args
+
+        self.reverse_p_array = [0 for _ in range(len(p_array))]
+        for idx in range(len(p_array)):
+            self.reverse_p_array[p_array[idx]] = idx
+
+        self.reverse_p_array = torch.LongTensor(self.reverse_p_array).view(self.args.img_size**2)
+
+    def set_parray(self, p_array):
+        self.reverse_p_array = [0 for _ in range(len(p_array))]
+        for idx in range(len(p_array)):
+            self.reverse_p_array[p_array[idx]] = idx
+
+        self.reverse_p_array = torch.LongTensor(self.reverse_p_array).view(self.args.img_size**2)
+
+    def forward(self, inputs):
+        input_shape = inputs.shape
+
+        inputs = inputs.view(input_shape[0], input_shape[1], input_shape[2] * input_shape[3])
+        inputs = inputs.permute(2, 0, 1)
+        res = inputs[self.reverse_p_array]
+
+        res = res.permute(1, 2, 0)
+        res = res.reshape(input_shape)
+
+        return res

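A minimal round-trip check of the permutation logic above (the 4x4 image size and tensor shapes are illustrative assumptions):

import numpy as np
import torch

p_array = np.random.permutation(16)              # flattened 4x4 "image"
x = torch.randn(2, 3, 4, 4)                      # (batch, channels, H, W)

# interleave, as in Interleaver2Dold.forward
flat = x.view(2, 3, 16).permute(2, 0, 1)
interleaved = flat[torch.LongTensor(p_array)]

# build the inverse permutation, as in DeInterleaver2Dold.__init__
reverse = torch.empty(16, dtype=torch.long)
reverse[torch.LongTensor(p_array)] = torch.arange(16)

restored = interleaved[reverse].permute(1, 2, 0).reshape(2, 3, 4, 4)
assert torch.equal(restored, x)                  # de-interleaving undoes interleaving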

+# TBD: change 2D interleavers
+# Play with real 2D interleavers: p_array with 2-step interleaving.
class Interleaver2D(torch.nn.Module):
    def __init__(self, args, p_array):
        super(Interleaver2D, self).__init__()
17 changes: 16 additions & 1 deletion loss.py
@@ -86,10 +86,25 @@ def customized_loss(output, X_train, args, size_average = True, noise = None, co
    BCE_loss_tmp = F.binary_cross_entropy(output, X_train, reduce=False)

    bce_loss = torch.mean(BCE_loss_tmp)
-   tmp, _ = torch.max(BCE_loss_tmp, dim=1, keepdim=False)
+   pos_loss = torch.mean(BCE_loss_tmp, dim=0)
+
+   tmp, _ = torch.max(pos_loss, dim=0)
    max_loss = torch.mean(tmp)

    loss = bce_loss + args.lambda_maxBCE * max_loss

+elif args.loss == 'sortBCE':
+    output = torch.clamp(output, 0.0, 1.0)
+    BCE_loss_tmp = F.binary_cross_entropy(output, X_train, reduce=False)
+
+    bce_loss = torch.mean(BCE_loss_tmp)
+    pos_loss = torch.mean(BCE_loss_tmp, dim=0)
+
+    # sort per-position losses and sum the 5 worst positions
+    tmp, _ = torch.sort(pos_loss.view(-1), descending=True)
+
+    sort_loss = torch.sum(tmp[:5])
+
+    loss = bce_loss + args.lambda_maxBCE * sort_loss

return loss
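To see what the sortBCE penalty measures, a small standalone sketch (the shapes are assumptions: output and X_train as (batch, block_len, 1)):

import torch
import torch.nn.functional as F

output = torch.rand(32, 100, 1)                       # decoder soft outputs
X_train = torch.randint(0, 2, (32, 100, 1)).float()   # transmitted bits

bce = F.binary_cross_entropy(output, X_train, reduction='none')
pos_loss = bce.mean(dim=0).view(-1)     # average BCE per code position
worst = torch.sort(pos_loss, descending=True)[0][:5]
penalty = worst.sum()                   # extra pressure on the 5 hardest positions

Compared with maxBCE, which penalizes only the single worst position, sortBCE spreads the penalty over the top five.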

8 changes: 8 additions & 0 deletions main.py
@@ -251,3 +251,11 @@ def import_dec(args):

33 changes: 33 additions & 0 deletions modulation.py
@@ -1 +1,34 @@
__author__ = 'yihanjiang'

+'''
+This requires upgrading the real-valued channels to complex channels, which takes quite a lot of engineering.
+For now just use BPSK to make things easier.
+BPSK is an overloaded idea here: it just means a real-valued channel.
+Modulation means each symbol has limited power.
+Input: continuous values.
+'''
+
+def modulation(input_signal, mod_mode='bpsk'):
+    if mod_mode == 'bpsk':
+        return input_signal
+
+    elif mod_mode == 'continuous_complex':
+        # build block_len * code_rate / 2 symbols.
+        pass
+
+    elif mod_mode == 'qpsk':
+        pass
+
+    elif mod_mode == 'qam':
+        pass
+
+
+def demod(rec_signal, demod_mode='bpsk'):
+    if demod_mode == 'bpsk':
+        return rec_signal
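The qpsk and qam branches are left as stubs. One possible shape for the qpsk case, sketched under the assumption that input_signal is a real tensor with an even number of values per block (the pairing and normalization scheme are guesses, not part of this commit):

import torch

def qpsk_modulate_sketch(input_signal):
    # Pair consecutive real values as (I, Q) components; the last dim of
    # the result holds the two components of one unit-power symbol.
    x = input_signal.reshape(-1, 2)
    power = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (power + 1e-8)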


