
Commit

Add a lot of experimental code.
yihanjiang committed Jan 11, 2020
1 parent 8e108e7 commit 2a94121
Showing 13 changed files with 890 additions and 125 deletions.
55 changes: 55 additions & 0 deletions channel_ae.py
@@ -71,3 +71,58 @@ def forward(self, input, fwd_noise):
x_dec = self.dec(received_codes)

return x_dec, codes



class Channel_ModAE(torch.nn.Module):
def __init__(self, args, enc, dec, mod, demod, modulation = 'qpsk'):
super(Channel_ModAE, self).__init__()
use_cuda = not args.no_cuda and torch.cuda.is_available()
self.this_device = torch.device("cuda" if use_cuda else "cpu")

self.args = args
self.enc = enc
self.dec = dec
self.mod = mod
self.demod = demod

def forward(self, input, fwd_noise):
# Setup Interleavers.
if self.args.is_interleave == 0:
pass

elif self.args.is_same_interleaver == 0:
interleaver = RandInterlv.RandInterlv(self.args.block_len, np.random.randint(0, 1000))

p_array = interleaver.p_array
self.enc.set_interleaver(p_array)
self.dec.set_interleaver(p_array)

else:  # self.args.is_same_interleaver == 1
interleaver = RandInterlv.RandInterlv(self.args.block_len, 0) # not random anymore!
p_array = interleaver.p_array
self.enc.set_interleaver(p_array)
self.dec.set_interleaver(p_array)

codes = self.enc(input)
symbols = self.mod(codes)

# Setup channel mode:
if self.args.channel in ['awgn', 't-dist', 'radar', 'ge_awgn']:
received_symbols = symbols + fwd_noise

elif self.args.channel == 'fading':
    print('Fading not implemented, falling back to AWGN')
    received_symbols = symbols + fwd_noise  # fallback so received_symbols is always defined

else:
    print('default AWGN channel')
    received_symbols = symbols + fwd_noise

if self.args.rec_quantize:
myquantize = MyQuantize.apply
received_symbols = myquantize(received_symbols, self.args.rec_quantize_level, self.args.rec_quantize_level)

x_rec = self.demod(received_symbols)
x_dec = self.dec(x_rec)

return x_dec, symbols
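
Taken on its own, the new Channel_ModAE can be smoke-tested with identity stand-ins. A minimal sketch, assuming only the class above; the _Stub module and the args values are hypothetical placeholders, not the repo's real encoder/modulator stack:

import torch
from types import SimpleNamespace

class _Stub(torch.nn.Module):
    # Hypothetical identity module exposing the interface Channel_ModAE expects.
    def forward(self, x):
        return x
    def set_interleaver(self, p_array):
        pass  # real encoders/decoders install the permutation here

args = SimpleNamespace(no_cuda=True, is_interleave=0, channel='awgn', rec_quantize=False)
model = Channel_ModAE(args, enc=_Stub(), dec=_Stub(), mod=_Stub(), demod=_Stub())

u = torch.randint(0, 2, (2, 100, 1)).float()  # (batch, block_len, 1) message bits
fwd_noise = 0.1 * torch.randn(2, 100, 1)      # noise shaped like the modulated symbols
x_dec, symbols = model(u, fwd_noise)          # identity stubs: symbols == u, x_dec == u + noise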
39 changes: 19 additions & 20 deletions channels.py
@@ -4,7 +4,7 @@
from utils import snr_db2sigma, snr_sigma2db
import numpy as np

-def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
+def generate_noise(noise_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
# SNRs at training
if test_sigma == 'default':
if args.channel == 'bec':
@@ -22,7 +22,7 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
this_sigma_low = snr_db2sigma(snr_low)
this_sigma_high= snr_db2sigma(snr_high)
# mixture of noise sigma.
-        this_sigma = (this_sigma_low - this_sigma_high) * torch.rand((X_train_shape[0], X_train_shape[1], args.code_rate_n)) + this_sigma_high
+        this_sigma = (this_sigma_low - this_sigma_high) * torch.rand(noise_shape) + this_sigma_high

else:
if args.channel in ['bec', 'bsc', 'ge']: # bsc/bec noises
@@ -32,25 +32,25 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):

# SNRs at testing
if args.channel == 'awgn':
-        fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float)
+        fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float)

elif args.channel == 't-dist':
-        fwd_noise = this_sigma * torch.from_numpy(np.sqrt((args.vv-2)/args.vv) * np.random.standard_t(args.vv, size = (X_train_shape[0], X_train_shape[1], args.code_rate_n))).type(torch.FloatTensor)
+        fwd_noise = this_sigma * torch.from_numpy(np.sqrt((args.vv-2)/args.vv) * np.random.standard_t(args.vv, size = noise_shape)).type(torch.FloatTensor)

elif args.channel == 'radar':
-        add_pos = np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n),
+        add_pos = np.random.choice([0.0, 1.0], noise_shape,
p=[1 - args.radar_prob, args.radar_prob])

-        corrupted_signal = args.radar_power* np.random.standard_normal( size = (X_train_shape[0], X_train_shape[1], args.code_rate_n) ) * add_pos
-        fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) +\
+        corrupted_signal = args.radar_power* np.random.standard_normal( size = noise_shape ) * add_pos
+        fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float) +\
torch.from_numpy(corrupted_signal).type(torch.FloatTensor)

elif args.channel == 'bec':
-        fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n),
+        fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], noise_shape,
p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor)

elif args.channel == 'bsc':
-        fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n),
+        fwd_noise = torch.from_numpy(np.random.choice([0.0, 1.0], noise_shape,
p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor)
elif args.channel == 'ge_awgn':
#G-E AWGN channel
@@ -59,12 +59,12 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
bsc_k = snr_db2sigma(snr_sigma2db(this_sigma) + 1) # accuracy on good state
bsc_h = snr_db2sigma(snr_sigma2db(this_sigma) - 1) # accuracy on bad state

-        fwd_noise = np.zeros((X_train_shape[0], X_train_shape[1], args.code_rate_n))
-        for batch_idx in range(X_train_shape[0]):
-            for code_idx in range(args.code_rate_n):
+        fwd_noise = np.zeros(noise_shape)
+        for batch_idx in range(noise_shape[0]):
+            for code_idx in range(noise_shape[2]):

good = True
-                for time_idx in range(X_train_shape[1]):
+                for time_idx in range(noise_shape[1]):
if good:
if test_sigma == 'default':
fwd_noise[batch_idx,time_idx, code_idx] = bsc_k[batch_idx,time_idx, code_idx]
@@ -80,7 +80,7 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
else:
print('Warning: unexpected state in G-E channel noise generation')

-        fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)* torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float)
+        fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)* torch.randn(noise_shape, dtype=torch.float)

elif args.channel == 'ge':
#G-E discrete channel
@@ -89,12 +89,12 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):
bsc_k = 1.0        # accuracy on good state
bsc_h = this_sigma # accuracy on bad state

-        fwd_noise = np.zeros((X_train_shape[0], X_train_shape[1], args.code_rate_n))
-        for batch_idx in range(X_train_shape[0]):
-            for code_idx in range(args.code_rate_n):
+        fwd_noise = np.zeros(noise_shape)
+        for batch_idx in range(noise_shape[0]):
+            for code_idx in range(noise_shape[2]):

good = True
-                for time_idx in range(X_train_shape[1]):
+                for time_idx in range(noise_shape[1]):
if good:
tmp = np.random.choice([0.0, 1.0], p=[1-bsc_k, bsc_k])
fwd_noise[batch_idx,time_idx, code_idx] = tmp
@@ -110,10 +110,9 @@ def generate_noise(X_train_shape, args, test_sigma = 'default', snr_low = 0.0, snr_high = 0.0, mode = 'encoder'):

else:
# Unspecified channel: default to AWGN.
-        fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float)
+        fwd_noise = this_sigma * torch.randn(noise_shape, dtype=torch.float)

return fwd_noise
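
What the new signature buys: the caller now chooses the noise shape outright, instead of the helper hard-coding (batch, block_len, code_rate_n), so the same function can serve modulated symbols whose length differs by mod_rate. A sketch of both call patterns; the batch size and the mod-rate arithmetic (taken from the -mod_rate help string in get_args.py) are illustrative:

batch_size = 500  # illustrative; args as produced by get_args()

# Coded-bit noise, the shape the old X_train_shape-based version always produced:
fwd_noise = generate_noise((batch_size, args.block_len, args.code_rate_n), args)

# Noise matched to 2-D modulated symbols, i.e. (B, L*R/mod_rate, 2):
sym_len = args.block_len * args.code_rate_n // args.mod_rate
fwd_noise_mod = generate_noise((batch_size, sym_len, 2), args)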




8 changes: 6 additions & 2 deletions cnn_utils.py
@@ -4,11 +4,12 @@

# utility for Same Shape CNN 1D
class SameShapeConv1d(torch.nn.Module):
-    def __init__(self, num_layer, in_channels, out_channels, kernel_size, activation = 'elu'):
+    def __init__(self, num_layer, in_channels, out_channels, kernel_size, activation = 'elu', no_act = False):
super(SameShapeConv1d, self).__init__()

self.cnns = torch.nn.ModuleList()
self.num_layer = num_layer
self.no_act = no_act
for idx in range(num_layer):
if idx == 0:
self.cnns.append(torch.nn.Conv1d(in_channels = in_channels, out_channels=out_channels,
@@ -36,7 +37,10 @@ def forward(self, inputs):
inputs = torch.transpose(inputs, 1,2)
x = inputs
for idx in range(self.num_layer):
-            x = self.activation(self.cnns[idx](x))
+            if self.no_act:
+                x = self.cnns[idx](x)
+            else:
+                x = self.activation(self.cnns[idx](x))

outputs = torch.transpose(x, 1,2)
return outputs
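
A quick sketch of the new no_act switch; input layout is (batch, length, channels), as the transposes in forward imply:

import torch

conv = SameShapeConv1d(num_layer=2, in_channels=3, out_channels=16,
                       kernel_size=5, no_act=True)
x = torch.randn(8, 100, 3)  # (batch, block_len, channels)
y = conv(x)                 # (8, 100, 16): a purely linear conv stack, no ELU applied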
13 changes: 3 additions & 10 deletions decoders.py
@@ -277,22 +277,15 @@ def forward(self, received):

from encoders import SameShapeConv1d
class DEC_LargeCNN2Int(torch.nn.Module):
-    def __init__(self, args, p_array):
+    def __init__(self, args, p_array1, p_array2):
super(DEC_LargeCNN2Int, self).__init__()
self.args = args

use_cuda = not args.no_cuda and torch.cuda.is_available()
self.this_device = torch.device("cuda" if use_cuda else "cpu")

-        self.interleaver1 = Interleaver(args, p_array)
-        self.deinterleaver1 = DeInterleaver(args, p_array)
-
-        seed2 = 1000
-        rand_gen2 = mtrand.RandomState(seed2)
-        p_array2 = rand_gen2.permutation(arange(args.block_len))
-
-        print('p_array1 dec', p_array)
-        print('p_array2 dec', p_array2)
+        self.interleaver1 = Interleaver(args, p_array1)
+        self.deinterleaver1 = DeInterleaver(args, p_array1)

self.interleaver2 = Interleaver(args, p_array2)
self.deinterleaver2 = DeInterleaver(args, p_array2)
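
Both networks now expect the caller to build the two permutations once and pass the same pair to encoder and decoder, rather than hard-coding seed 1000 internally. A sketch of the expected setup; the seed choices are assumptions:

import numpy as np
from numpy.random import mtrand

# args as produced by get_args()
rand_gen1 = mtrand.RandomState(0)     # any fixed seed for the first permutation
p_array1 = rand_gen1.permutation(np.arange(args.block_len))
rand_gen2 = mtrand.RandomState(1000)  # 1000 mirrors the seed2 the old code hard-coded
p_array2 = rand_gen2.permutation(np.arange(args.block_len))

enc = ENC_interCNN2Int(args, p_array1, p_array2)
dec = DEC_LargeCNN2Int(args, p_array1, p_array2)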
42 changes: 13 additions & 29 deletions encoders.py
@@ -101,35 +101,28 @@ def enc_act(self, inputs):

def power_constraint(self, x_input):

    if not self.args.precompute_norm_stats:
        if self.args.no_code_norm:
            return x_input
        else:
            this_mean = torch.mean(x_input)
            this_std  = torch.std(x_input)
            x_input_norm = (x_input - this_mean) * 1.0 / this_std
    else:
        x_input_norm = (x_input - self.mean_scalar) / self.std_scalar

    if self.training:
        # save pretrained mean/std. Pretrained not implemented in this code version.
        try:
            if self.args.precompute_norm_stats:
                self.num_test_block += 1.0
                self.mean_scalar = (self.mean_scalar * (self.num_test_block - 1) + this_mean) / self.num_test_block
                self.std_scalar  = (self.std_scalar * (self.num_test_block - 1) + this_std) / self.num_test_block
        except:
            print('group normalization seems weird!')

        if self.args.train_channel_mode == 'block_norm_ste':
            stequantize = STEQuantize.apply
            x_input_norm = stequantize(x_input_norm, self.args)

    else:
        if self.args.test_channel_mode == 'block_norm_ste':
            stequantize = STEQuantize.apply
            x_input_norm = stequantize(x_input_norm, self.args)

    if self.args.enc_truncate_limit > 0:
        x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit)

    return x_input_norm
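
In the default path (no precomputed stats, normalization on, no STE or truncation) the method reduces to whole-block standardization. A sketch of the invariant it enforces:

import torch

x = torch.randn(4, 100, 3) * 3.7 + 1.2       # arbitrary encoder output
x_norm = (x - torch.mean(x)) / torch.std(x)  # what the default branch computes
print(torch.mean(x_norm).item())             # ~0.0
print(torch.mean(x_norm ** 2).item())        # ~1.0: unit average power per coded symbol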

# Encoder with interleaver. Support different code rate.
class ENC_turbofy_rate2(ENCBase):
@@ -388,7 +381,7 @@ def forward(self, inputs):
#######################################################
from cnn_utils import SameShapeConv1d
class ENC_interCNN2Int(ENCBase):
-    def __init__(self, args, p_array):
+    def __init__(self, args, p_array1, p_array2):
# turbofy only for code rate 1/3
super(ENC_interCNN2Int, self).__init__(args)
self.args = args
@@ -411,16 +404,7 @@ def __init__(self, args, p_array):

self.enc_linear_3 = torch.nn.Linear(args.enc_num_unit, 1)

-        self.interleaver1 = Interleaver(args, p_array)
-
-
-        seed2 = 1000
-        rand_gen2 = mtrand.RandomState(seed2)
-        p_array2 = rand_gen2.permutation(arange(args.block_len))
-
-        print('p_array1', p_array)
-        print('p_array2', p_array2)
-
+        self.interleaver1 = Interleaver(args, p_array1)
self.interleaver2 = Interleaver(args, p_array2)


@@ -451,7 +435,7 @@ def forward(self, inputs):
x_p2 = self.enc_cnn_3(x_sys_int2)
x_p2 = self.enc_act(self.enc_linear_3(x_p2))

-        x_tx = torch.cat([x_sys,x_p1, x_p2], dim = 2)
+        x_tx = torch.cat([x_sys, x_p1, x_p2], dim = 2)

codes = self.power_constraint(x_tx)
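
The forward pass concatenates the systematic stream with the two parity streams produced through the two interleavers. A minimal shape check, with illustrative sizes:

import torch

x_sys = torch.randn(4, 100, 1)  # systematic branch
x_p1 = torch.randn(4, 100, 1)   # parity branch through interleaver 1
x_p2 = torch.randn(4, 100, 1)   # parity branch through interleaver 2
x_tx = torch.cat([x_sys, x_p1, x_p2], dim=2)
assert x_tx.shape == (4, 100, 3)  # three streams per bit: code rate 1/3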

31 changes: 29 additions & 2 deletions get_args.py
@@ -19,6 +19,7 @@ def get_args():
'rate3_cnn2d',
'Turbo_rate3_757', # Turbo Code, rate 1/3, 757.
'Turbo_rate3_lte', # Turbo Code, rate 1/3, LTE.
'turboae_2int', # experimental, use multiple interleavers
],
default='TurboAE_rate3_cnn2d')

@@ -33,6 +34,7 @@ def get_args():
'nbcjr_rate3', # NeuralBCJR Decoder, rate 1/3, allow ft size.
'rate3_cnn', # CNN Encoder, rate 1/3. No Interleaver
'rate3_cnn2d',
'turboae_2int', # experimental, use multiple interleavers
],
default='TurboAE_rate3_cnn2d')
################################################################
@@ -81,7 +83,7 @@ def get_args():
parser.add_argument('-extrinsic', type=int, default=1)
parser.add_argument('-num_iter_ft', type=int, default=5)
parser.add_argument('-is_interleave', type=int, default=1, help='0 is not interleaving, 1 is fixed interleaver, >1 is random interleaver')
-parser.add_argument('-is_same_interleaver', type=int, default=0, help='not random interleaver, potentially finetune?')
+parser.add_argument('-is_same_interleaver', type=int, default=1, help='not random interleaver, potentially finetune?')
parser.add_argument('-is_parallel', type=int, default=0)
# CNN related
parser.add_argument('-enc_kernel_size', type=int, default=5)
@@ -91,12 +93,14 @@ def get_args():
parser.add_argument('-enc_num_layer', type=int, default=2)
parser.add_argument('-dec_num_layer', type=int, default=5)

-parser.add_argument('-enc_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')

parser.add_argument('-dec_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')
+parser.add_argument('-enc_num_unit', type=int, default=100, help = 'This is CNN number of filters, and RNN units')

parser.add_argument('-enc_act', choices=['tanh', 'selu', 'relu', 'elu', 'sigmoid', 'linear'], default='elu', help='only elu works')
parser.add_argument('-dec_act', choices=['tanh', 'selu', 'relu', 'elu', 'sigmoid', 'linear'], default='linear')

parser.add_argument('-num_ber_puncture', type=int, default=5, help = 'Puncture bad BER positions')

################################################################
# Training Algorithm related parameters
@@ -133,6 +137,29 @@ def get_args():
default='block_norm')
parser.add_argument('-enc_truncate_limit', type=float, default=0, help='0 means no truncation')

################################################################
# Modulation related parameters
################################################################
parser.add_argument('-mod_rate', type=int, default=2, help = 'code: (B, L, R), mod_output: (B, L*R/mod_rate, 2)')
parser.add_argument('-mod_num_layer', type=int, default=1, help = '')
parser.add_argument('-mod_num_unit', type=int, default=20, help = '')
parser.add_argument('-demod_num_layer', type=int, default=1, help = '')
parser.add_argument('-demod_num_unit', type=int, default=20, help = '')

parser.add_argument('-mod_lr', type = float, default=0.005, help='modulation learning rate')
parser.add_argument('-demod_lr', type = float, default=0.005, help='demodulation learning rate')

parser.add_argument('-num_train_mod', type=int, default=1, help = '')
parser.add_argument('-num_train_demod', type=int, default=5, help = '')

parser.add_argument('-mod_pc',
choices=['qpsk','symbol_power', 'block_power'],
default='block_power')

parser.add_argument('--no_code_norm', action='store_true', default=False,
                    help='do not normalize the encoder output; the modulation stage does the work')
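
A hypothetical way to exercise the new modulation flags without a training run; the flag spellings are as registered above, while the values and the import path are assumptions:

import sys
from get_args import get_args

sys.argv = ['main.py',
            '-mod_rate', '2', '-mod_num_layer', '1', '-mod_num_unit', '20',
            '-mod_lr', '0.005', '-num_train_demod', '5', '--no_code_norm']
args = get_args()
print(args.mod_rate, args.mod_pc, args.no_code_norm)  # 2 block_power True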



################################################################
# STE related parameters

