Skip to content

Commit

Permalink
Adapt string conversion to recent updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
nsbgn committed May 17, 2022
1 parent 1ad0737 commit 8145da3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 22 deletions.
13 changes: 0 additions & 13 deletions tests/test_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,6 @@ def test_taxonomy_with_parameterized_type_alias(self):
F(B, B): set(),
}

def test_string_schematic_type(self):
"""
Test that schematic types are printed with the names of their schematic
variables.
"""
A = TypeOperator()
B = TypeOperator()
TypeOperator(supertype=A)
f = Operator(type=lambda x: x [x << {A, B}])
lang = Language(scope=locals())

self.assertEqual(str(f.type), "x | x[A, B]")

def test_parse_inline_typing(self):
A = TypeOperator()
x = Operator(type=A)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,29 @@ def test_curried_function_signature_same_as_uncurried(self):
(A1, A2) ** A
)

def test_printing(self):
# Schematic types are printed with the names of their variables
f = TypeSchema(lambda foo: foo)
self.assertEqual(str(f), "foo")

# Subtype constraints
A = TypeOperator('A')
f = TypeSchema(lambda x: x [x <= A])
g = TypeSchema(lambda x: x [A <= x])
self.assertEqual(str(f), "x [x <= A]")
self.assertEqual(str(g), "x [A <= x]")

# Elimination constraints
B = TypeOperator('B')
f = TypeSchema(lambda x: x [x << (A, B)])
self.assertEqual(str(f), "x [x << (A, B)]")

# Subtype bounds printed using same notation as elimination constraints
f = TypeSchema(lambda x: x [x << A])
g = TypeSchema(lambda x: x [A << x])
self.assertEqual(str(f), "x [x << A]")
self.assertEqual(str(g), "A") # this subtype bound should be fixed


if __name__ == '__main__':
unittest.main()
18 changes: 9 additions & 9 deletions transformation_algebra/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self, schema: Callable[..., TypeInstance]):
def __str__(self) -> str:
names = signature(self.schema).parameters
variables = [TypeVariable() for _ in names]
return self.schema(*variables).text(
return self.schema(*variables).fix(prefer_lower=True).text(
with_constraints=True,
labels={v: k for k, v in zip(names, variables)})

Expand Down Expand Up @@ -314,18 +314,18 @@ def text(self,
assert result

if with_constraints:
result_aux = [result]
result_aux = []
for v in self.variables():
if v.lower:
result_aux.append(f"{v.text(*args)} {v.lower}")
result_aux.append(f"{v.lower} << {v.text(*args)}")
if v.upper:
result_aux.append(f"{v.text(*args)} {v.upper}")
result_aux.append(f"{v.text(*args)} << {v.upper}")

result_aux.extend(
c.text(*args) for c in self.constraints())
return ' | '.join(result_aux)
else:
return result
if result_aux:
result += f" [{', '.join(result_aux)}]"
return result

def fix(self, prefer_lower: bool = True) -> TypeInstance:
"""
Expand Down Expand Up @@ -873,8 +873,8 @@ def __init__(

def text(self, *args, **kwargs) -> str:
return (
f"{self.reference.text(*args, **kwargs)}["
f"{', '.join(a.text(*args, **kwargs) for a in self.alternatives)}]"
f"{self.reference.text(*args, **kwargs)} << ("
f"{', '.join(a.text(*args, **kwargs) for a in self.alternatives)})"
)

def __iter__(self) -> Iterator[TypeInstance]:
Expand Down

0 comments on commit 8145da3

Please sign in to comment.