From e5b928d79e09ea45b94b26a31edf177698c99fde Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Wed, 28 Aug 2024 04:59:03 -0400 Subject: [PATCH] typing --- i6_models/parts/fsa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/fsa.py index 7ae0328e..f4b2df22 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/fsa.py @@ -3,7 +3,7 @@ __all__ = ["TorchFsaBuilder", "WeightedFsa"] from functools import reduce -from typing import Iterable, NamedTuple, Tuple, TypeVar +from typing import Iterable, NamedTuple, Tuple, Union import numpy as np import torch @@ -37,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa: self.start_end_states, ) - def to(self, device: torch.device) -> WeightedFsa: + def to(self, device: Union[str, torch.device]) -> WeightedFsa: """Move the tensors to a given device. This wraps around the PyTorch `Tensor.to(device)` method.""" return WeightedFsa._make(tensor.to(device) for tensor in self)