Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Mann committed Aug 28, 2024
1 parent a81a16b commit e5b928d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions i6_models/parts/fsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__all__ = ["TorchFsaBuilder", "WeightedFsa"]

from functools import reduce
from typing import Iterable, NamedTuple, Tuple, TypeVar
from typing import Iterable, NamedTuple, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -37,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa:
self.start_end_states,
)

def to(self, device: torch.device) -> WeightedFsa:
def to(self, device: Union[str, torch.device]) -> WeightedFsa:
"""Move the tensors to a given device. This wraps around the
PyTorch `Tensor.to(device)` method."""
return WeightedFsa._make(tensor.to(device) for tensor in self)
Expand Down

0 comments on commit e5b928d

Please sign in to comment.