Skip to content

Commit

Permalink
add missing typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Sivan Ravid committed Jan 22, 2024
1 parent 862abc9 commit f6ee01d
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions fuse/data/ops/tests/test_ops_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from typing import Optional, OrderedDict, Union, List
from typing import Optional, OrderedDict, Union, List, Any

import torch

Expand Down Expand Up @@ -281,7 +281,7 @@ def test_op_replace_element(self) -> None:
test op_replace_element on each of possible inputs it supports: string, list, tensor
"""

def _apply_op_replace(input, to_replace, replace_with):
def _apply_op_replace(input: Any, to_replace: Any, replace_with: str) -> Any:
sample_dict = NDict({})
sample_dict["data.input"] = input
sample_dict = op_call(
Expand Down Expand Up @@ -321,7 +321,9 @@ def test_op_replace_any_element(self) -> None:
test op_replace_any_element on each of possible inputs it supports: string, list, tensor
"""

def _apply_op_replace(input, to_replace, replace_with):
def _apply_op_replace_any(
input: Any, to_replace: Any, replace_with: str
) -> Any:
sample_dict = NDict({})
sample_dict["data.input"] = input
sample_dict = op_call(
Expand All @@ -341,19 +343,19 @@ def _apply_op_replace(input, to_replace, replace_with):
replace_with = "X"
expected_result = "aXXXXa"

res = _apply_op_replace(input, to_replace, replace_with)
res = _apply_op_replace_any(input, to_replace, replace_with)
self.assertEqual(res, expected_result)

# input is a list
input = [1, 2, 3, 2, 1]
to_replace = [2, 3]
replace_with = 0
expected_result = [1, 0, 0, 0, 1]
res = _apply_op_replace(input, to_replace, replace_with)
res = _apply_op_replace_any(input, to_replace, replace_with)
self.assertEqual(res, expected_result)

# input is a tensor
res = _apply_op_replace(torch.tensor(input), to_replace, replace_with)
res = _apply_op_replace_any(torch.tensor(input), to_replace, replace_with)
self.assertTrue(torch.equal(res, torch.tensor(expected_result)))


Expand Down

0 comments on commit f6ee01d

Please sign in to comment.