Skip to content

Commit

Permalink
Merge pull request #15 from juglab/merger
Browse files Browse the repository at this point in the history
Merger
  • Loading branch information
tibuch authored Mar 1, 2021
2 parents aa6394e + 3fe1e8c commit 25ec06f
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 46 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.18-py3-none-any.whl /fourier_image_transformers-0.1.18-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.18-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.18.Singularity`
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.19-py3-none-any.whl /fourier_image_transformers-0.1.19-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.19-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.19.Singularity`

Build singularity container:
`sudo singularity build fit_v0.1.18.simg v0.1.18.Singularity`
`sudo singularity build fit_v0.1.19.simg v0.1.19.Singularity`
6 changes: 3 additions & 3 deletions fit/datamodules/GroundTruthDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def __init__(self, train_gt_images, val_gt_images, test_gt_images, inner_circle=
self.val_gt_images = val_gt_images
self.test_gt_images = test_gt_images
assert self.train_gt_images.shape[1] == self.train_gt_images.shape[2], 'Train images are not square.'
assert self.train_gt_images.shape[1] % 2 == 1, 'Train image size has to be odd.'
# assert self.train_gt_images.shape[1] % 2 == 1, 'Train image size has to be odd.'
assert self.val_gt_images.shape[1] == self.val_gt_images.shape[2], 'Val images are not square.'
assert self.val_gt_images.shape[1] % 2 == 1, 'Val image size has to be odd.'
# assert self.val_gt_images.shape[1] % 2 == 1, 'Val image size has to be odd.'
assert self.test_gt_images.shape[1] == self.test_gt_images.shape[2], 'Test images are not square.'
assert self.test_gt_images.shape[1] % 2 == 1, 'Test image size has to be odd.'
# assert self.test_gt_images.shape[1] % 2 == 1, 'Test image size has to be odd.'

self.shape = (self.train_gt_images.shape[1], self.train_gt_images.shape[2])
if inner_circle:
Expand Down
15 changes: 7 additions & 8 deletions fit/datamodules/super_res/SRecDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class MNISTSResFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 27
IMG_SHAPE = 28

def __init__(self, root_dir, batch_size):
"""
Expand All @@ -38,9 +38,8 @@ def setup(self, stage: Optional[str] = None):
mnist_train_val = MNIST(self.root_dir, train=True, download=True).data.type(torch.float32)
np.random.seed(1612)
perm = np.random.permutation(mnist_train_val.shape[0])
mnist_train = mnist_train_val[perm[:55000], 1:, 1:]
mnist_val = mnist_train_val[perm[55000:], 1:, 1:]
mnist_test = mnist_test[:, 1:, 1:]
mnist_train = mnist_train_val[perm[:55000]]
mnist_val = mnist_train_val[perm[55000:]]

assert mnist_train.shape[1] == MNISTSResFourierTargetDataModule.IMG_SHAPE
assert mnist_train.shape[2] == MNISTSResFourierTargetDataModule.IMG_SHAPE
Expand All @@ -66,13 +65,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=MNISTSResFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=MNISTSResFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down Expand Up @@ -123,13 +122,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down
16 changes: 8 additions & 8 deletions fit/datamodules/tomo_rec/TRecDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down Expand Up @@ -219,13 +219,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down Expand Up @@ -288,13 +288,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down Expand Up @@ -360,13 +360,13 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=2)
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
Expand Down
2 changes: 1 addition & 1 deletion fit/modules/SResTransformerModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, d_model, img_shape=27,
self.dft_shape = (img_shape, img_shape // 2 + 1)

self.sres = SResTransformer(d_model=self.hparams.d_model,
y_coords_img=self.y_coords_img, x_coords_img=self.x_coords_img,
y_coords_img=self.y_coords_img, x_coords_img=self.x_coords_img, flatten_order=self.dst_flatten_coords,
attention_type='causal-linear',
n_layers=self.hparams.n_layers,
n_heads=self.hparams.n_heads,
Expand Down
16 changes: 11 additions & 5 deletions fit/modules/TRecTransformerModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=img_shape))

self.trec = TRecTransformer(d_model=self.hparams.d_model,
y_coords_proj=y_coords_proj, x_coords_proj=x_coords_proj,
y_coords_img=y_coords_img, x_coords_img=x_coords_img,
y_coords_proj=y_coords_proj, x_coords_proj=x_coords_proj, flatten_proj=self.src_flatten_coords,
y_coords_img=y_coords_img, x_coords_img=x_coords_img, flatten_img=self.dst_flatten_coords,
attention_type=self.hparams.attention_type,
n_layers=self.hparams.n_layers,
n_heads=self.hparams.n_heads,
Expand Down Expand Up @@ -90,7 +90,9 @@ def configure_optimizers(self):
def _real_loss(self, pred_img, target_fc, mag_min, mag_max):
dft_target = convert_to_dft(fc=target_fc, mag_min=mag_min, mag_max=mag_max,
dst_flatten_coords=self.dst_flatten_coords, img_shape=self.hparams.img_shape)
dft_target *= self.mask
if self.bin_factor > 1:
dft_target *= self.mask

y_target = torch.roll(torch.fft.irfftn(dft_target, dim=[1, 2], s=2 * (self.hparams.img_shape,)),
2 * (self.hparams.img_shape // 2,), (1, 2))
return F.mse_loss(pred_img, y_target)
Expand All @@ -116,10 +118,14 @@ def criterion(self, pred_fc, pred_img, target_fc, mag_min, mag_max):
def _bin_data(self, x_fc, y_fc):
shells = (self.hparams.detector_len // 2 + 1) / self.bin_factor
num_sino_fcs = np.clip(self.num_angles * int(shells + 1), 1, x_fc.shape[1])
num_target_fcs = np.sum(self.dst_order <= shells)

if self.bin_factor > 1:
num_target_fcs = np.sum(self.dst_order <= shells)
else:
num_target_fcs = self.trec.decoder_input.shape[1]

x_fc_ = x_fc[:, self.src_flatten_coords][:, :num_sino_fcs]
out_pos_emb = self.trec.pos_embedding_target.pe[:, :num_target_fcs]
out_pos_emb = self.trec.decoder_input[:, :num_target_fcs]
y_fc_ = y_fc[:, self.dst_flatten_coords][:, :num_target_fcs]

return x_fc_, out_pos_emb, y_fc_
Expand Down
5 changes: 3 additions & 2 deletions fit/transformers/PositionalEncoding2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@


class PositionalEncoding2D(torch.nn.Module):
def __init__(self, d_model, y_coords, x_coords, dropout=0.0, persistent=False):
def __init__(self, d_model, y_coords, x_coords, flatten_order, dropout=0.0, persistent=False):
super(PositionalEncoding2D, self).__init__()
self.dropout = torch.nn.Dropout(p=dropout)
self.d_model = d_model

pe = self.positional_encoding_2D(self.d_model, y_coords, x_coords)
pe = pe.reshape(-1, pe.shape[0]).unsqueeze(0)
pe = torch.movedim(pe, 0, -1).unsqueeze(0)
pe = pe[:, flatten_order]

self.register_buffer('pe', pe, persistent=persistent)

Expand Down
3 changes: 2 additions & 1 deletion fit/transformers/SResTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class SResTransformer(torch.nn.Module):
def __init__(self,
d_model,
y_coords_img, x_coords_img,
y_coords_img, x_coords_img, flatten_order,
attention_type="linear",
n_layers=4,
n_heads=4,
Expand All @@ -23,6 +23,7 @@ def __init__(self,
d_model // 2,
y_coords_img,
x_coords_img,
flatten_order=flatten_order,
persistent=False
)

Expand Down
20 changes: 11 additions & 9 deletions fit/transformers/TRecTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class TRecTransformer(torch.nn.Module):
def __init__(self,
d_model,
y_coords_proj, x_coords_proj,
y_coords_img, x_coords_img,
y_coords_proj, x_coords_proj, flatten_proj,
y_coords_img, x_coords_img, flatten_img,
attention_type="linear",
n_layers=4,
n_heads=4,
Expand All @@ -25,6 +25,7 @@ def __init__(self,
d_model // 2,
y_coords_proj,
x_coords_proj,
flatten_order=flatten_proj,
persistent=False
)

Expand All @@ -39,8 +40,9 @@ def __init__(self,
attention_dropout=attention_dropout
).get()

self.pos_embedding_target = PositionalEncoding2D(d_model, y_coords_img, x_coords_img)

self.pos_embedding_target = PositionalEncoding2D(d_model // 2, y_coords_img, x_coords_img, flatten_order=flatten_img)
decoder_input = torch.cat([torch.rand(self.pos_embedding_target.pe.shape), self.pos_embedding_target.pe], dim=2)
self.register_buffer('decoder_input', decoder_input, persistent=True)
self.decoder = TransformerDecoderBuilder.from_kwargs(
self_attention_type=attention_type,
cross_attention_type=attention_type,
Expand All @@ -59,13 +61,13 @@ def __init__(self,
)

self.conv_block = torch.nn.Sequential(
torch.nn.Conv2d(1, d_query, kernel_size=3, stride=1, padding=1),
torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(d_query),
torch.nn.Conv2d(d_query, d_query, kernel_size=3, stride=1, padding=1),
torch.nn.BatchNorm2d(32),
torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(d_query),
torch.nn.Conv2d(d_query, 1, kernel_size=1, stride=1, padding=0)
torch.nn.BatchNorm2d(32),
torch.nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0)
)

def forward(self, x, out_pos_emb, mag_min, mag_max, dst_flatten_coords, img_shape, attenuation):
Expand Down
29 changes: 24 additions & 5 deletions fit/utils/tomo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,25 @@ def get_detector_length(proj_space):
return num_px_horiz


def get_proj_coords(angles, det_len):
def get_proj_coords_pol(angles, det_len):
tmp = det_len // 2 + 1
a = np.rad2deg(-angles + np.pi / 2.)
r = np.arange(0, tmp)
r, a = np.meshgrid(r, a)
flatten_indices = np.argsort(r.flatten())
r = r.flatten()[flatten_indices]
a = a.flatten()[flatten_indices]
xcoords = r * np.cos(np.deg2rad(a))
ycoords = (tmp) + r * np.sin(np.deg2rad(a)) - 1
return torch.from_numpy(xcoords), torch.from_numpy(ycoords), flatten_indices
return torch.from_numpy(r), torch.from_numpy(np.deg2rad(a)), flatten_indices


def get_img_coords(img_shape, det_len):
def get_proj_coords_cart(angles, det_len):
r, a, flatten_indices = get_proj_coords_pol(angles, det_len)
xcoords = r * torch.cos(a)
ycoords = (det_len // 2) + r * torch.sin(a)
return xcoords, ycoords, flatten_indices


def get_img_coords_cart(img_shape, det_len):
xcoords, ycoords = np.meshgrid(np.linspace(0, det_len // 2, num=img_shape // 2 + 1, endpoint=True),
np.concatenate([np.linspace(0, det_len // 2, img_shape // 2, False),
np.linspace(det_len // 2, det_len - 1, img_shape // 2 + 1)]))
Expand All @@ -43,3 +48,17 @@ def get_img_coords(img_shape, det_len):
xcoords = xcoords.flatten()[flatten_indices]
ycoords = ycoords.flatten()[flatten_indices]
return torch.from_numpy(xcoords), torch.from_numpy(ycoords), flatten_indices, order


def get_img_coords_pol(img_shape, det_len):
xcoords, ycoords, flatten_indices, order = get_img_coords_cart(img_shape, det_len)
ycoords -= img_shape // 2
r = torch.sqrt(xcoords ** 2 + ycoords ** 2)
phi = torch.atan2(ycoords, xcoords)
return r, phi, flatten_indices, order


def pol2cart(rho, phi):
x = rho * torch.cos(phi)
y = rho * torch.sin(phi)
return (x, y)
2 changes: 1 addition & 1 deletion fit/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.18'
__version__ = '0.1.19'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"tifffile",
"tqdm",
"pytorch-fast-transformers",
"dival",
"dival<0.6.0",
"pytorch-lightning",
"jupyter"
]
Expand Down
Loading

0 comments on commit 25ec06f

Please sign in to comment.