Fftshift for FourierOp #143
Conversation
I tested the following cases that worked fine:
I tried to test KTrajectoryPulseq as well, but it is not working at all currently, even after fixing #141 and #144.
```python
import torch.nn.functional as F
# ...
def change_data_shape(dat: torch.Tensor, dat_shape_new: tuple[int, ...]) -> torch.Tensor:
```
A couple of possible issues:
- The name sounds like a reshape operation.
- Consider `data` instead of `dat`.
- Consider `dat_shape_new` -> `new_shape` (or I might call you Yoda).
- What if the data has a different number of dimensions than the new shape?
Suggested change:

```python
import math

import torch
import torch.nn.functional as F


def pad_or_crop(data: torch.Tensor, new_shape: tuple[int, ...] | torch.Size) -> torch.Tensor:
    """Change shape of data by cropping or zero-padding.

    Parameters
    ----------
    data
        data to pad or crop
    new_shape
        target shape of data

    Returns
    -------
    data padded or cropped to shape
    """
    if len(new_shape) != data.ndim:
        raise ValueError('length of new shape should match dimensions of data')
    npad = []
    for old, new in zip(data.shape, new_shape):
        diff = new - old
        # trunc ensures padding and cropping lead to matching asymmetry for odd shape differences
        before = math.trunc(diff / 2)
        after = diff - before
        npad.append(before)
        npad.append(after)
    if any(npad):
        # F.pad expects the paddings in reversed order (last dimension first)
        data = F.pad(data, npad[::-1])
    return data
```
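For illustration, a quick round-trip check (hypothetical shapes, not part of the suggestion) of how `math.trunc` makes the crop undo the pad for odd size differences:

```python
x = torch.arange(5.0)
padded = pad_or_crop(x, (8,))        # diff = 3: pads 1 before, 2 after
cropped = pad_or_crop(padded, (5,))  # diff = -3: crops 1 before, 2 after
assert torch.equal(cropped, x)       # the asymmetries match, so the round trip is lossless
```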
Or, including the along_dimension functionality:
```python
import math

import torch
import torch.nn.functional as F


def normalize_index(ndim: int, index: int) -> int:
    """Normalize possibly negative indices.

    Parameters
    ----------
    ndim
        number of dimensions
    index
        index to normalize. negative indices count from the end.

    Raises
    ------
    IndexError
        if index is outside [-ndim, ndim)
    """
    if 0 <= index < ndim:
        return index
    elif -ndim <= index < 0:
        return ndim + index
    else:
        raise IndexError(f'Invalid index {index} for {ndim} data dimensions')


def pad_or_crop(
    data: torch.Tensor,
    new_shape: tuple[int, ...] | torch.Size,
    dim: None | tuple[int, ...] = None,
) -> torch.Tensor:
    """Change shape of data by cropping or zero-padding.

    Parameters
    ----------
    data
        data to pad or crop
    new_shape
        desired shape of data
    dim
        dimensions the new_shape corresponds to. None (default) is interpreted as the last len(new_shape) dimensions.

    Returns
    -------
    data padded or cropped to shape
    """
    if len(new_shape) > data.ndim:
        raise ValueError('length of new shape should not exceed dimensions of data')
    if dim is None:  # use the last len(new_shape) dimensions
        new_shape = data.shape[: -len(new_shape)] + tuple(new_shape)
    else:
        if len(new_shape) != len(dim):
            raise ValueError('length of new shape should match length of dim')
        dim = tuple(normalize_index(data.ndim, idx) for idx in dim)  # raises if any not in [-data.ndim, data.ndim)
        if len(dim) != len(set(dim)):  # this is why we normalize
            raise ValueError('repeated values are not allowed in dim')
        # write the new sizes into the full shape by index assignment
        full_shape = torch.tensor(data.shape)
        full_shape[list(dim)] = torch.tensor(new_shape)
        new_shape = tuple(full_shape.tolist())
    npad = []
    for old, new in zip(data.shape, new_shape):
        diff = new - old
        # trunc ensures padding and cropping lead to matching asymmetry for odd shape differences
        after = math.trunc(diff / 2)
        before = diff - after
        npad.append(before)
        npad.append(after)
    if any(npad):
        # F.pad expects the paddings in reversed order (last dimension first)
        data = F.pad(data, npad[::-1])
    return data
```
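A hypothetical usage sketch of the `dim` variant (shapes made up for illustration):

```python
data = torch.ones(4, 6, 8)
# crop dimension 1 from 6 to 4 and pad the last dimension from 8 to 10
out = pad_or_crop(data, (4, 10), dim=(1, -1))
assert out.shape == (4, 4, 10)
```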
I incorporated your second suggestion with the dimension option to be able to use it directly in the PadOp.
```python
dim = tuple(normalize_index(data.ndim, idx) for idx in dim)  # raises if any not in [-data.ndim, data.ndim)
if len(dim) != len(set(dim)):  # this is why we normalize
    raise ValueError('repeated values are not allowed in dims')
new_shape_full = torch.tensor(data.shape)
```
why is this a tensor if you don't use tuple indexing in the next line?
```python
if len(dim) != len(set(dim)):  # this is why we normalize
    raise ValueError('repeated values are not allowed in dims')
new_shape_full = torch.tensor(data.shape)
for i, d in enumerate(dim):
```
is there a reason for this loop instead of just indexing with dims?
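For reference, the indexing alternative could look like this (a sketch, assuming `dim` already holds normalized non-negative indices); note that a plain tuple index would be interpreted as multi-dimensional indexing, hence the list:

```python
new_shape_full = torch.tensor(data.shape)
new_shape_full[list(dim)] = torch.tensor(new_shape)  # index assignment instead of the loop
```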
```python
    new_shape_full[d] = new_shape[i]

npad = []
for old, new in zip(torch.tensor(data.shape), new_shape_full):
```
why are the shapes tensors anyway?
otherwise the difference in the line below does not work
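For comparison, a sketch of a tensor-free variant: since the elements of `torch.Size` are plain Python ints, the difference works without converting the shapes to tensors:

```python
new_shape_full = list(data.shape)
for i, d in enumerate(dim):
    new_shape_full[d] = new_shape[i]
npad = []
for old, new in zip(data.shape, new_shape_full):
    diff = new - old  # plain int arithmetic, no tensors needed
```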
The padding is broken for `dim=None`.
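A minimal regression check one might add for that path (hypothetical shapes):

```python
data = torch.ones(2, 3, 4)
out = pad_or_crop(data, (6,))  # dim=None: only the last dimension should change
assert out.shape == (2, 3, 6)
assert torch.equal(out[..., 1:-1], data)  # original data centered between the zero pads
```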
* FastFourierOp and PadOp added