diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index 21bce82..685285d 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -893,6 +893,7 @@ def slice_along_dim( ) -> torch.Tensor: """ Slices a tensor along a specified dimension. + RH 2022 Args: X (torch.Tensor): @@ -906,8 +907,6 @@ def slice_along_dim( (torch.Tensor): sliced_tensor (torch.Tensor): Sliced tensor. - - RH 2022 """ slices = [slice(None)] * X.ndim slices[dim] = idx