Skip to content

Commit

Permalink
Merge pull request #11 from juglab/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
tibuch authored Jan 18, 2021
2 parents d5a41a5 + 00f862a commit 6eaf6af
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 23 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.9-py3-none-any.whl /fourier_image_transformers-0.1.9-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.9-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.9.Singularity`
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.10-py3-none-any.whl /fourier_image_transformers-0.1.10-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.10-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.10.Singularity`

Build singularity container:
`sudo singularity build fit_v0.1.9.simg v0.1.9.Singularity`
`sudo singularity build fit_v0.1.10.simg v0.1.10.Singularity`
7 changes: 7 additions & 0 deletions fit/datamodules/tomo_rec/TRecDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def setup(self, stage: Optional[str] = None):
gt_train *= circle
gt_val *= circle
gt_test *= circle

self.mean = gt_train.mean()
self.std = gt_train.std()

gt_train = normalize(gt_train, self.mean, self.std)
gt_val = normalize(gt_val, self.mean, self.std)
gt_test = normalize(gt_test, self.mean, self.std)
self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=450, impl='astra_cpu', inner_circle=self.inner_circle)
Expand Down
59 changes: 39 additions & 20 deletions fit/modules/TRecTransformerModule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import TrainResult, EvalResult
from torch.optim.lr_scheduler import ReduceLROnPlateau

from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule
Expand Down Expand Up @@ -69,11 +68,11 @@ def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords
dropout=self.hparams.dropout,
attention_dropout=self.hparams.attention_dropout)

x, y = torch.meshgrid(torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1,
MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1),
torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1,
MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1))
self.register_buffer('circle', torch.sqrt(x ** 2. + y ** 2.) <= MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2)
x, y = torch.meshgrid(torch.arange(-self.hparams.img_shape // 2 + 1,
self.hparams.img_shape // 2 + 1),
torch.arange(-self.hparams.img_shape // 2 + 1,
self.hparams.img_shape // 2 + 1))
self.register_buffer('circle', torch.sqrt(x ** 2. + y ** 2.) <= self.hparams.img_shape // 2)

    def forward(self, x, out_pos_emb):
        """Forward pass: delegate to the wrapped TRec transformer.

        Args:
            x: encoded input Fourier coefficients (shape per ``self.trec`` —
               presumably (batch, seq, features); TODO confirm against TRec).
            out_pos_emb: positional embedding for the output/target positions.

        Returns:
            Whatever ``self.trec.forward`` returns, unchanged.
        """
        return self.trec.forward(x, out_pos_emb)
Expand Down Expand Up @@ -105,7 +104,7 @@ def _fc_loss(self, pred_fc, target_fc, mag_min, mag_max):
c1_unit = c1 / amp1
c2_unit = c2 / amp2

amp_loss = (1 + torch.pow(amp1 - amp2, 2))
amp_loss = (1 + torch.pow(pred_fc[..., 0] - target_fc[..., 0], 2)).unsqueeze(-1)
phi_loss = (2 - torch.sum(c1_unit * c2_unit, dim=-1, keepdim=True))
return torch.mean(amp_loss * phi_loss), torch.mean(amp_loss), torch.mean(phi_loss)

Expand Down Expand Up @@ -147,13 +146,25 @@ def training_epoch_end(self, outputs):
self.log('Train/amp_loss', torch.mean(torch.stack(amp_loss)), logger=True, on_epoch=True)
self.log('Train/phi_loss', torch.mean(torch.stack(phi_loss)), logger=True, on_epoch=True)

def _monitor_mse(self, pred, y_real, mag_min, mag_max):
dft_pred = convert_to_dft(fc=pred, mag_min=mag_min, mag_max=mag_max,
dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)
y_hat = torch.roll(torch.fft.irfftn(dft_pred, dim=[1, 2], s=2 * (self.hparams.img_shape,)),
    def _gt_bin_mse(self, y_fc, y_real, mag_min, mag_max):
        """MSE between the image reconstructed from the (binned) ground-truth
        Fourier coefficients ``y_fc`` and the real image ``y_real``.

        This measures the information loss caused by binning alone (no model
        prediction involved) and is used as a reference value ("bin_mse").

        Args:
            y_fc: ground-truth Fourier coefficients, normalized to [mag_min, mag_max].
            y_real: ground-truth images in real space.
            mag_min, mag_max: magnitude bounds used to undo the normalization.
        """
        # Undo the magnitude normalization and scatter the flat coefficient
        # sequence back onto the 2D DFT grid.
        dft_y = convert_to_dft(fc=y_fc, mag_min=mag_min, mag_max=mag_max,
                               dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)
        # Inverse real FFT back to image space; torch.roll re-centers the image
        # (the DFT grid is stored with the zero frequency at the corner).
        y_hat = torch.roll(torch.fft.irfftn(dft_y, dim=[1, 2], s=2 * (self.hparams.img_shape,)),
                           2 * (self.hparams.img_shape // 2,), (1, 2))

        return F.mse_loss(y_hat, y_real)

    def _val_psnr(self, pred_img, y_real):
        """Mean PSNR over the validation batch.

        Both prediction and ground truth are mapped back to their original
        intensity range via the datamodule's mean/std before comparison, and
        masked with ``self.circle`` so only pixels inside the reconstruction
        circle contribute. The dynamic range for each PSNR is taken from the
        masked ground-truth image itself (max - min).

        Args:
            pred_img: predicted images, shape (batch, H, W) — assumed; TODO confirm.
            y_real: ground-truth images of the same shape.

        Returns:
            Scalar tensor: the batch-mean PSNR.
        """
        # NOTE(review): despite the "_norm" suffix these are DE-normalized
        # (original-scale) images — the suffix is misleading.
        pred_img_norm = denormalize(pred_img, self.trainer.datamodule.mean, self.trainer.datamodule.std)
        y_real_norm = denormalize(y_real, self.trainer.datamodule.mean, self.trainer.datamodule.std)
        psnrs = []
        for i in range(len(pred_img_norm)):
            gt = self.circle * y_real_norm[i]
            # Per-sample dynamic range, computed on the masked ground truth.
            psnrs.append(PSNR(gt, self.circle * pred_img_norm[i],
                              drange=gt.max()-gt.min()))

        return torch.mean(torch.stack(psnrs))

def validation_step(self, batch, batch_idx):
x_fc, y_fc, y_real, (mag_min, mag_max) = batch
x_fc_, out_pos_emb, y_fc_ = self._bin_data(x_fc, y_fc)
Expand All @@ -165,21 +176,24 @@ def validation_step(self, batch, batch_idx):
val_loss, amp_loss, phi_loss = self.criterion(pred_fc, pred_img, y_fc_, mag_min, mag_max)

val_mse = F.mse_loss(pred_img, y_real)
bin_mse = self._monitor_mse(y_fc_, y_real, mag_min=mag_min, mag_max=mag_max)
val_psnr = self._val_psnr(pred_img, y_real)
bin_mse = self._gt_bin_mse(y_fc_, y_real, mag_min=mag_min, mag_max=mag_max)
self.log_dict({'val_loss': val_loss})
self.log_dict({'val_mse': val_mse})
self.log_dict({'val_psnr': val_psnr})
self.log_dict({'bin_mse': bin_mse})
if batch_idx == 0:
self.log_val_images(pred_img, x_fc, y_fc_, y_real, mag_min, mag_max)
return {'val_loss': val_loss, 'val_mse': val_mse, 'bin_mse': bin_mse, 'amp_loss': amp_loss,
return {'val_loss': val_loss, 'val_mse': val_mse, 'val_psnr': val_psnr, 'bin_mse': bin_mse,
'amp_loss': amp_loss,
'phi_loss': phi_loss}

def log_val_images(self, pred_img, x, y_fc, y_real, mag_min, mag_max):
x_fc = convert2FC(x, mag_min, mag_max)
dft_target = convert_to_dft(fc=y_fc, mag_min=mag_min, mag_max=mag_max,
dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)

for i in range(3):
for i in range(min(3, len(pred_img))):
x_dft = fft_interpolate(self.x_coords_proj.cpu().numpy(), self.y_coords_proj.cpu().numpy(),
self.x_coords_img.cpu().numpy(), self.y_coords_img.cpu().numpy(),
x_fc[i][self.src_flatten_coords].cpu().numpy(),
Expand Down Expand Up @@ -212,10 +226,12 @@ def log_val_images(self, pred_img, x, y_fc, y_real, mag_min, mag_max):
def validation_epoch_end(self, outputs):
val_loss = [o['val_loss'] for o in outputs]
val_mse = [o['val_mse'] for o in outputs]
val_psnr = [o['val_psnr'] for o in outputs]
bin_mse = [o['bin_mse'] for o in outputs]
amp_loss = [d['amp_loss'] for d in outputs]
phi_loss = [d['phi_loss'] for d in outputs]
mean_val_mse = torch.mean(torch.stack(val_mse))
mean_val_psnr = torch.mean(torch.stack(val_psnr))
mean_bin_mse = torch.mean(torch.stack(bin_mse))
if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < (
self.hparams.alpha * mean_bin_mse) and self.bin_factor > 1:
Expand All @@ -224,10 +240,14 @@ def validation_epoch_end(self, outputs):
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))
print('Reduced bin_factor to {}.'.format(self.bin_factor))

if self.bin_factor > 1:
self.trainer.lr_schedulers[0]['scheduler']._reset()

self.bin_count += 1

self.log('Train/avg_val_loss', torch.mean(torch.stack(val_loss)), logger=True, on_epoch=True)
self.log('Train/avg_val_mse', mean_val_mse, logger=True, on_epoch=True)
self.log('Train/avg_val_psnr', mean_val_psnr, logger=True, on_epoch=True)
self.log('Train/avg_bin_mse', mean_bin_mse, logger=True, on_epoch=True)
self.log('Train/avg_val_amp_loss', torch.mean(torch.stack(amp_loss)), logger=True, on_epoch=True)
self.log('Train/avg_val_phi_loss', torch.mean(torch.stack(phi_loss)), logger=True, on_epoch=True)
Expand All @@ -241,15 +261,14 @@ def test_step(self, batch, batch_idx):
x_fc_, out_pos_emb, y_fc_ = self._bin_data(x, y)

_, pred_img = self.trec.forward(x_fc_, out_pos_emb, mag_min=mag_min, mag_max=mag_max,
dst_flatten_coords=self.dst_flatten_coords,
img_shape=self.hparams.img_shape,
attenuation=self.mask)

dst_flatten_coords=self.dst_flatten_coords,
img_shape=self.hparams.img_shape,
attenuation=self.mask)

gt = denormalize(y_real[0], self.trainer.datamodule.mean, self.trainer.datamodule.std)
pred_img = denormalize(pred_img[0], self.trainer.datamodule.mean, self.trainer.datamodule.std)

return PSNR(self.circle * gt, self.circle * pred_img, drange=torch.tensor(255., dtype=torch.float32))
return PSNR(self.circle * gt, self.circle * pred_img, drange=gt.max()-gt.min())

def test_epoch_end(self, outputs):
outputs = torch.stack(outputs)
Expand Down
2 changes: 1 addition & 1 deletion fit/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.9'
__version__ = '0.1.10'
13 changes: 13 additions & 0 deletions scripts/LoDoPaB_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"batch_size": 2,
"num_angles": 40,
"n_heads": 8,
"d_query": 32,
"init_bin_factor": 8,
"bin_factor_cd": 5,
"alpha": 1.5,
"lr": 0.0001,
"attention_type": "linear",
"n_layers": 4,
"max_epochs": 1000
}
107 changes: 107 additions & 0 deletions scripts/TRec_LoDoPaB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import argparse
import glob
import json
from os.path import exists

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule
from fit.datamodules.tomo_rec.TRecDataModule import LoDoPaBFourierTargetDataModule
from fit.modules import TRecTransformerModule
from fit.utils.tomo_utils import get_proj_coords, get_img_coords


def _reload_checkpoint(ckpt_path, model):
    """Reload a TRecTransformerModule from ``ckpt_path``.

    The coordinate/geometry tensors are not serialized in the checkpoint, so
    they are taken over from an already-constructed ``model`` instance.
    """
    return TRecTransformerModule.load_from_checkpoint(
        ckpt_path,
        y_coords_proj=model.y_coords_proj,
        x_coords_proj=model.x_coords_proj,
        y_coords_img=model.y_coords_img,
        x_coords_img=model.x_coords_img,
        angles=model.angles,
        src_flatten_coords=model.src_flatten_coords,
        dst_flatten_coords=model.dst_flatten_coords,
        dst_order=model.dst_order)


def _test_and_dump(trainer, model, dm, out_path):
    """Run the test loop and write mean/SEM PSNR to ``out_path`` as JSON."""
    test_res = trainer.test(model, datamodule=dm)[0]
    out_res = {
        "Mean PSNR": test_res["Mean PSNR"].item(),
        "SEM PSNR": test_res["SEM PSNR"].item()
    }
    with open(out_path, 'w') as f:
        json.dump(out_res, f)


def main():
    """Train a TRec transformer on LoDoPaB and evaluate the last and the
    best-validation checkpoints, writing PSNR results to JSON files.

    Reads the experiment configuration from the JSON file given via
    ``--exp_config``. Aborts early if a ``lightning_logs`` directory already
    exists, to avoid overwriting a previous run.
    """
    seed_everything(22122020)

    parser = argparse.ArgumentParser(description="Train TRec transformer on LoDoPaB.")
    # required=True turns a missing flag into a clear usage error instead of
    # a TypeError from open(None) further down.
    parser.add_argument("--exp_config", required=True)

    args = parser.parse_args()

    # Fail fast before building the (expensive) datamodule and model.
    if exists('lightning_logs'):
        print('Some experiments already exist. Abort.')
        return 0

    with open(args.exp_config) as f:
        conf = json.load(f)

    dm = LoDoPaBFourierTargetDataModule(batch_size=conf['batch_size'],
                                        num_angles=conf['num_angles'])
    dm.setup()

    det_len = dm.gt_ds.get_ray_trafo().geometry.detector.shape[0]

    proj_xcoords, proj_ycoords, src_flatten = get_proj_coords(angles=dm.gt_ds.get_ray_trafo().geometry.angles,
                                                              det_len=det_len)
    target_xcoords, target_ycoords, dst_flatten, order = get_img_coords(img_shape=dm.IMG_SHAPE, det_len=det_len)

    model = TRecTransformerModule(d_model=conf['n_heads'] * conf['d_query'],
                                  y_coords_proj=proj_ycoords, x_coords_proj=proj_xcoords,
                                  y_coords_img=target_ycoords, x_coords_img=target_xcoords,
                                  src_flatten_coords=src_flatten, dst_flatten_coords=dst_flatten,
                                  dst_order=order,
                                  angles=dm.gt_ds.get_ray_trafo().geometry.angles, img_shape=dm.IMG_SHAPE,
                                  detector_len=det_len,
                                  init_bin_factor=conf['init_bin_factor'], bin_factor_cd=conf['bin_factor_cd'],
                                  alpha=conf['alpha'],
                                  lr=conf['lr'], weight_decay=0.01,
                                  attention_type=conf['attention_type'], n_layers=conf['n_layers'],
                                  n_heads=conf['n_heads'], d_query=conf['d_query'], dropout=0.1, attention_dropout=0.1)

    trainer = Trainer(max_epochs=conf['max_epochs'],
                      gpus=1,
                      checkpoint_callback=ModelCheckpoint(
                          filepath=None,
                          save_top_k=1,
                          verbose=False,
                          save_last=True,
                          monitor='Train/avg_val_mse',
                          mode='min',
                          prefix='best_val_loss_'
                      ),
                      deterministic=True)

    trainer.fit(model, datamodule=dm)

    # Evaluate the last checkpoint ...
    model = _reload_checkpoint('lightning_logs/version_0/checkpoints/best_val_loss_-last.ckpt', model)
    _test_and_dump(trainer, model, dm, 'last_ckpt_results.json')

    # ... and the best-validation checkpoint.
    best_path = glob.glob('lightning_logs/version_0/checkpoints/best_val_loss_-epoch*')[0]
    model = _reload_checkpoint(best_path, model)
    _test_and_dump(trainer, model, dm, 'best_ckpt_results.json')


# Standard script entry point: run training/evaluation only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 6eaf6af

Please sign in to comment.