Add wrapper_flexible_args decorator to squeeze_multiple_dims and zscore functions
RichieHakim committed Mar 27, 2024
1 parent c8ae72f commit 1a3e55e
Showing 1 changed file with 2 additions and 6 deletions.
bnpm/torch_helpers.py (8 changes: 2 additions & 6 deletions)
@@ -635,6 +635,7 @@ def unravel_index(
         index = index // dim
     return tuple(out[::-1])
 
+@misc.wrapper_flexible_args(['dim', 'axis'])
 def squeeze_multiple_dims(
     arr: torch.Tensor,
     dims: Tuple[int, int] = (0, 1)
@@ -814,12 +815,12 @@ def __call__(self, x: torch.sparse.FloatTensor) -> torch.Tensor:
         return values[self.idx]
 
 
+@misc.wrapper_flexible_args(['dim', 'axis'])
 def zscore(
     X: torch.Tensor,
     dim: Optional[int] = None,
     ddof: int = 0,
     nan_policy: str = 'propagate',
-    axis: Optional[int] = None
 ) -> torch.Tensor:
     """
     Computes the z-score of a tensor.
@@ -842,18 +843,13 @@ def zscore(
                 * ``'raise'``: throws an error
                 * ``'omit'``: performs the calculations ignoring nan values \n
             (Default is ``'propagate'``)
-        axis (Optional[int]):
-            Axis or axes along which the z-score is computed.
-            The default is to compute the z-score of the flattened array.
-            (Default is ``None``)
     Returns:
         (torch.Tensor):
            Z-scored tensor.
     RH 2022
     """
-    dim = axis if (axis is not None) and (dim is None) else dim
     assert dim is not None, 'Must specify dimension to compute z-score over.'
 
     if nan_policy == 'omit':
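Note on the change: the body of misc.wrapper_flexible_args is not shown in this diff, so the snippet below is only a hedged sketch of how such a decorator could remap an axis= keyword onto dim=, which is what makes the separate axis parameter and the manual dim = axis if ... shim removable. The wrapper_flexible_args and simplified zscore defined here are illustrative stand-ins, not the bnpm implementations.

# Hypothetical sketch, NOT the bnpm implementation: a decorator that treats a
# list of keyword names as aliases for one parameter (the first name in the list).
import functools
from typing import Optional

import torch


def wrapper_flexible_args(aliases):
    """Remap any keyword in aliases[1:] onto aliases[0] before calling the function."""
    canonical = aliases[0]

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for alias in aliases[1:]:
                if alias in kwargs:
                    if canonical in kwargs:
                        raise TypeError(f"Got values for both '{canonical}' and '{alias}'")
                    kwargs[canonical] = kwargs.pop(alias)  # e.g. axis=0 -> dim=0
            return func(*args, **kwargs)
        return wrapper
    return decorator


@wrapper_flexible_args(['dim', 'axis'])
def zscore(X: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    # Simplified stand-in for bnpm.torch_helpers.zscore (no ddof / nan_policy handling).
    assert dim is not None, 'Must specify dimension to compute z-score over.'
    return (X - X.mean(dim=dim, keepdim=True)) / X.std(dim=dim, keepdim=True)


x = torch.randn(4, 5)
# Both spellings reach the same parameter, so a per-function
# "dim = axis if (axis is not None) and (dim is None) else dim" shim is no longer needed.
assert torch.allclose(zscore(x, dim=0), zscore(x, axis=0))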
