From 52d1be4368637ee3913e4fe47ea914b1e2e6b9ca Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 9 Apr 2024 10:01:12 -0400 Subject: [PATCH] poc: add Decimal[bits, places] syntax --- vyper/semantics/analysis/utils.py | 4 +-- vyper/semantics/types/__init__.py | 4 ++- vyper/semantics/types/primitives.py | 42 ++++++++++++++++++++++++++--- vyper/semantics/types/utils.py | 2 +- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4b751e7406..a330a12eac 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -22,7 +22,7 @@ from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.bytestrings import BytesT, StringT -from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT +from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT, DecimalT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT from vyper.utils import checksum_encode, int_to_fourbytes @@ -303,7 +303,7 @@ def types_from_Constant(self, node): # special handling for bytestrings since their # class objects are in the type map, not the type itself # (worth rethinking this design at some point.) - if t in (BytesT, StringT): + if t in (BytesT, StringT, DecimalT): t = t.from_literal(node) # any more validation which needs to occur diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 59a20dd99f..400ad83880 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -9,7 +9,7 @@ def _get_primitive_types(): - res = [BoolT(), DecimalT()] + res = [BoolT()] res.extend(IntegerT.all()) res.extend(BytesM_T.all()) @@ -21,6 +21,8 @@ def _get_primitive_types(): # note: since bytestrings are parametrizable, the *class* objects # are in the namespace instead of concrete type objects. res.extend([BytesT, StringT]) + # ditto for Decimals + res.append(DecimalT) ret = {t._id: t for t in res} ret.update(_get_sequence_types()) diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 4e8af80aac..f841e1e54a 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -10,6 +10,7 @@ CompilerPanic, InvalidLiteral, InvalidOperation, + InvalidType, UnimplementedException, OverflowException, VyperException, ) @@ -313,9 +314,7 @@ def SINT(bits): class DecimalT(NumericT): typeclass = "decimal" - _bits = 168 # TODO generalize - _decimal_places = 10 # TODO generalize - _id = "decimal" + _id = "Decimal" _is_signed = True _invalid_ops = ( vy_ast.Pow, @@ -331,6 +330,43 @@ class DecimalT(NumericT): ast_type = Decimal + def __init__(self, bits=168, decimal_places=10): + self._bits = bits + self._decimal_places = decimal_places + + if bits != 168 or decimal_places != 10: + raise UnimplementedException("Not implemented: {repr(self)}", hint="only Decimal[168, 10] is currently available") + + def __repr__(self): + return f"Decimal[{self._bits}, {self._decimal_places}]" + + @classmethod + def from_annotation(cls, node): + def _fail(): + raise InvalidType("not a valid Decimal", hint="expected: Decimal[, None: + if not isinstance(node, vy_ast.Decimal): + # TODO: check bits, places + raise TypeMismatch("Not a decimal") + + @classmethod + def from_literal(cls, node): + return DecimalT(168, 10) + def validate_numeric_op(self, node) -> None: try: super().validate_numeric_op(node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 93cf85d5f8..7a1f154911 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -101,7 +101,7 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: if isinstance(node, vy_ast.Subscript): # ex. HashMap, DynArray, Bytes, static arrays - if node.value.get("id") in ("HashMap", "Bytes", "String", "DynArray"): + if node.value.get("id") in ("HashMap", "Bytes", "String", "DynArray", "Decimal"): assert isinstance(node.value, vy_ast.Name) # mypy hint type_ctor = namespace[node.value.id] else: