diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index 9aff870..4bd2290 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -113,13 +113,14 @@ def is_numeric(self) -> bool: return self in (BaseType.INTEGER, BaseType.UNSIGNED_INTEGER) +# TODO: make this abstract? @dataclass -class NadaType: - """Nada type class. +class DslType: + """DSL type class. - This is the parent class of all nada types. + This is the parent class of all DSL types. - In Nada, all the types wrap Operations. For instance, an addition between two integers + In the DSL, all the types wrap Operations. For instance, an addition between two integers is represented like this SecretInteger(child=Addition(...)). In MIR, the representation is based around operations. A MIR operation points to other @@ -145,7 +146,7 @@ def __init__(self, child: OperationType): """ self.child = child if self.child is not None: - self.child.store_in_ast(self.metatype().to_mir()) + self.child.store_in_ast(self.type().to_mir()) def __bool__(self): raise NotImplementedError @@ -161,5 +162,5 @@ def is_literal(cls) -> bool: return False @abstractmethod - def metatype(self): + def type(self): """Returns a meta type for this NadaType.""" diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index 0f4c7f4..06d67b3 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -14,7 +14,7 @@ ReduceASTOperation, UnaryASTOperation, ) -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType # Wildcard import due to non-zero types from nada_dsl.nada_types.scalar_types import * # pylint: disable=W0614:wildcard-import @@ -108,12 +108,12 @@ def store_in_ast(self, ty): ) -class TupleMetaType(MetaType): +class TupleType(DslType): """Marker type for Tuples.""" is_compound = True - def __init__(self, left_type: MetaType, right_type: MetaType): + def __init__(self, left_type: DslType, right_type: DslType): self.left_type = left_type self.right_type = right_type @@ -131,7 +131,7 @@ def to_mir(self): @dataclass -class Tuple(Generic[T, U], NadaType): +class Tuple(Generic[T, U], DslType): """The Tuple type""" left_type: T @@ -144,11 +144,11 @@ def __init__(self, child, left_type: T, right_type: U): super().__init__(self.child) @classmethod - def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": + def new(cls, left_value: DslType, right_value: DslType) -> "Tuple[T, U]": """Constructs a new Tuple.""" return Tuple( - left_type=left_value.metatype(), - right_type=right_value.metatype(), + left_type=left_value.type(), + right_type=right_value.type(), child=TupleNew( child=(left_value, right_value), source_ref=SourceRef.back_frame(), @@ -156,27 +156,27 @@ def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": ) @classmethod - def generic_type(cls, left_type: U, right_type: T) -> TupleMetaType: + def generic_type(cls, left_type: U, right_type: T) -> TupleType: """Returns the generic type for this Tuple""" - return TupleMetaType(left_type=left_type, right_type=right_type) + return TupleType(left_type=left_type, right_type=right_type) - def metatype(self): - """Metatype for Tuple""" - return TupleMetaType(self.left_type, self.right_type) + def type(self): + """Type for Tuple""" + return TupleType(self.left_type, self.right_type) -def _generate_accessor(ty: Any, accessor: Any) -> NadaType: +def _generate_accessor(ty: Any, accessor: Any) -> DslType: if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix raise TypeError("Literals are not supported in accessors") return ty.instantiate(accessor) -class NTupleMetaType(MetaType): +class NTupleType(DslType): """Marker type for NTuples.""" is_compound = True - def __init__(self, types: List[MetaType]): + def __init__(self, types: List[DslType]): self.types = types def instantiate(self, child_or_value): @@ -192,7 +192,7 @@ def to_mir(self): @dataclass -class NTuple(NadaType): +class NTuple(DslType): """The NTuple type""" types: List[Any] @@ -205,7 +205,7 @@ def __init__(self, child, types: List[Any]): @classmethod def new(cls, values: List[Any]) -> "NTuple": """Constructs a new NTuple.""" - types = [value.metatype() for value in values] + types = [value.type() for value in values] return NTuple( types=types, child=NTupleNew( @@ -214,7 +214,7 @@ def new(cls, values: List[Any]) -> "NTuple": ), ) - def __getitem__(self, index: int) -> NadaType: + def __getitem__(self, index: int) -> DslType: if index >= len(self.types): raise IndexError(f"Invalid index {index} for NTuple.") @@ -226,9 +226,9 @@ def __getitem__(self, index: int) -> NadaType: return _generate_accessor(self.types[index], accessor) - def metatype(self): - """Metatype for NTuple""" - return NTupleMetaType(self.types) + def type(self): + """Type for NTuple""" + return NTupleType(self.types) @dataclass @@ -261,12 +261,12 @@ def store_in_ast(self, ty: object): ) -class ObjectMetaType(MetaType): +class ObjectType(DslType): """Marker type for Objects.""" is_compound = True - def __init__(self, types: Dict[str, MetaType]): + def __init__(self, types: Dict[str, DslType]): self.types = types def to_mir(self): @@ -280,7 +280,7 @@ def instantiate(self, child_or_value): @dataclass -class Object(NadaType): +class Object(DslType): """The Object type""" types: Dict[str, Any] @@ -293,7 +293,7 @@ def __init__(self, child, types: Dict[str, Any]): @classmethod def new(cls, values: Dict[str, Any]) -> "Object": """Constructs a new Object.""" - types = {key: value.metatype() for key, value in values.items()} + types = {key: value.type() for key, value in values.items()} return Object( types=types, child=ObjectNew( @@ -302,7 +302,7 @@ def new(cls, values: Dict[str, Any]) -> "Object": ), ) - def __getattr__(self, attr: str) -> NadaType: + def __getattr__(self, attr: str) -> DslType: if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" @@ -316,9 +316,9 @@ def __getattr__(self, attr: str) -> NadaType: return _generate_accessor(self.types[attr], accessor) - def metatype(self): - """Metatype for Object""" - return ObjectMetaType(types=self.types) + def type(self): + """Type for Object""" + return ObjectType(types=self.types) @dataclass @@ -412,7 +412,7 @@ def store_in_ast(self, ty: NadaTypeRepr): ) -class ArrayMetaType(MetaType): +class ArrayType(DslType): """Marker type for arrays.""" is_compound = True @@ -428,7 +428,7 @@ def to_mir(self): # and apply the same logic when the function gets passed to .map() or .reduce() # so we now the size of the array if self.size is None: - raise NotImplementedError("ArrayMetaType.to_mir") + raise NotImplementedError("ArrayType.to_mir") return { "Array": { "inner_type": self.contained_type.to_mir(), @@ -441,7 +441,7 @@ def instantiate(self, child_or_value): @dataclass -class Array(Generic[T], NadaType): +class Array(Generic[T], DslType): """Nada Array type. This is the representation of arrays in Nada MIR. @@ -461,14 +461,14 @@ class Array(Generic[T], NadaType): size: int def __init__(self, child, size: int, contained_type: T = None): - self.contained_type = contained_type or child.metatype() + self.contained_type = contained_type or child.type() self.size = size self.child = ( child if contained_type is not None else getattr(child, "child", None) ) if self.child is not None: - self.child.store_in_ast(self.metatype().to_mir()) + self.child.store_in_ast(self.type().to_mir()) def __iter__(self): raise NotAllowedException( @@ -495,9 +495,9 @@ def map(self: "Array[T]", function) -> "Array": def reduce(self: "Array[T]", function, initial: R) -> R: """The Reduce operation for arrays.""" self.check_not_constant(self.contained_type) - self.check_not_constant(initial.metatype()) + self.check_not_constant(initial.type()) function = create_nada_fn( - function, args_ty=[initial.metatype(), self.contained_type] + function, args_ty=[initial.type(), self.contained_type] ) return function.return_type.instantiate( Reduce( @@ -514,7 +514,7 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]": raise IncompatibleTypesError("Cannot zip arrays of different size") return Array( size=self.size, - contained_type=TupleMetaType( + contained_type=TupleType( left_type=self.contained_type, right_type=other.contained_type, ), @@ -551,7 +551,7 @@ def new(cls, *args) -> "Array[T]": raise TypeError("All arguments must be of the same type") return Array( - contained_type=first_arg.metatype(), + contained_type=first_arg.type(), size=len(args), child=ArrayNew( child=args, @@ -559,9 +559,9 @@ def new(cls, *args) -> "Array[T]": ), ) - def metatype(self): - """Metatype for Array""" - return ArrayMetaType(self.contained_type, self.size) + def type(self): + """Type for Array""" + return ArrayType(self.contained_type, self.size) @dataclass @@ -597,10 +597,10 @@ class NTupleNew: Represents the creation of a new Tuple. """ - child: List[NadaType] + child: List[DslType] source_ref: SourceRef - def __init__(self, child: List[NadaType], source_ref: SourceRef): + def __init__(self, child: List[DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -623,10 +623,10 @@ class ObjectNew: Represents the creation of a new Object. """ - child: Dict[str, NadaType] + child: Dict[str, DslType] source_ref: SourceRef - def __init__(self, child: Dict[str, NadaType], source_ref: SourceRef): + def __init__(self, child: Dict[str, DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -644,10 +644,10 @@ def store_in_ast(self, ty: object): def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: """The Unzip operation for Arrays.""" - right_type = ArrayMetaType( + right_type = ArrayType( contained_type=array.contained_type.right_type, size=array.size ) - left_type = ArrayMetaType( + left_type = ArrayType( contained_type=array.contained_type.left_type, size=array.size ) @@ -670,7 +670,7 @@ def __init__(self, child: List[T], source_ref: SourceRef): self.child = child self.source_ref = source_ref - def store_in_ast(self, ty: NadaType): + def store_in_ast(self, ty: DslType): """Store this ArrayNew object in the AST.""" AST_OPERATIONS[self.id] = NewASTOperation( id=self.id, diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index 8b08f10..f9454a4 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -15,7 +15,7 @@ next_operation_id, ) from nada_dsl.nada_types.generics import T, R -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType class NadaFunctionArg(Generic[T]): @@ -65,7 +65,7 @@ def __init__( function: Callable[[T], R], return_type: R, source_ref: SourceRef, - child: NadaType, + child: DslType, ): self.child = child self.id = function_id @@ -98,7 +98,7 @@ class NadaFunctionCall(Generic[R]): """Represents a call to a Nada Function.""" fn: NadaFunction - args: List[NadaType] + args: List[DslType] source_ref: SourceRef def __init__(self, nada_function, args, source_ref): @@ -106,7 +106,7 @@ def __init__(self, nada_function, args, source_ref): self.args = args self.fn = nada_function self.source_ref = source_ref - self.store_in_ast(nada_function.return_type.metatype().to_mir()) + self.store_in_ast(nada_function.return_type.type().to_mir()) def store_in_ast(self, ty): """Store this function call in the AST.""" @@ -147,7 +147,7 @@ def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]: child = fn(*nada_args_type_wrapped) - return_type = child.metatype() + return_type = child.type() return NadaFunction( function_id, function=fn, diff --git a/nada_dsl/nada_types/generics.py b/nada_dsl/nada_types/generics.py index 6923c60..ab53633 100644 --- a/nada_dsl/nada_types/generics.py +++ b/nada_dsl/nada_types/generics.py @@ -2,8 +2,8 @@ from typing import TypeVar -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType -R = TypeVar("R", bound=NadaType) -T = TypeVar("T", bound=NadaType) -U = TypeVar("U", bound=NadaType) +R = TypeVar("R", bound=DslType) +T = TypeVar("T", bound=DslType) +U = TypeVar("U", bound=DslType) diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 50d513a..cba1d1c 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -5,10 +5,11 @@ from dataclasses import dataclass from typing import Union, TypeVar from typing_extensions import Self +from nada_dsl.audit.abstract import AbstractBoolean from nada_dsl.operations import * from nada_dsl.program_io import Literal from nada_dsl import SourceRef -from . import NadaType, Mode, BaseType, OperationType +from . import DslType, Mode, BaseType, OperationType # Constant dictionary that stores all the Nada types and is use to # convert from the (mode, base_type) representation to the concrete Nada type @@ -47,7 +48,7 @@ AnyBoolean = Union["Boolean", "PublicBoolean", "SecretBoolean"] -class ScalarType(NadaType): +class ScalarType(DslType): """The Nada Scalar type. This is the super class for all scalar types in Nada. These are: @@ -288,7 +289,7 @@ def public_equals_operation(left: ScalarType, right: ScalarType) -> "PublicBoole ) -class BooleanType(ScalarType): +class AbstractBooleanType(ScalarType): """This abstraction represents all boolean types: - Boolean, PublicBoolean, SecretBoolean It provides common operation implementations for all the boolean types, defined above. @@ -349,7 +350,7 @@ def binary_logical_operation( return SecretBoolean(child=operation) -class MetaType(ABC): +class DslType(ABC): """Abstract meta type""" is_constant = False @@ -365,7 +366,7 @@ def to_mir(self): """Returns a MIR representation of this meta type""" -class MetaTypePassthroughMixin(MetaType): +class TypePassthroughMixin(DslType): """Mixin for meta types""" def instantiate(self, child_or_value): @@ -378,8 +379,8 @@ def to_mir(self): if name.startswith("Public"): name = name[len("Public") :].lstrip() - if name.endswith("MetaType"): - name = name[: -len("MetaType")].rstrip() + if name.endswith("Type"): + name = name[: -len("Type")].rstrip() return name @@ -405,11 +406,11 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True - def metatype(self): - return IntegerMetaType() + def type(self): + return IntegerType() -class IntegerMetaType(MetaTypePassthroughMixin): +class IntegerType(TypePassthroughMixin): """Meta type for integers""" ty = Integer @@ -442,11 +443,11 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True - def metatype(self): - return UnsignedIntegerMetaType() + def type(self): + return UnsignedIntegerType() -class UnsignedIntegerMetaType(MetaTypePassthroughMixin): +class UnsignedIntegerType(TypePassthroughMixin): """Meta type for unsigned integers""" ty = UnsignedInteger @@ -455,7 +456,7 @@ class UnsignedIntegerMetaType(MetaTypePassthroughMixin): @register_scalar_type(Mode.CONSTANT, BaseType.BOOLEAN) -class Boolean(BooleanType): +class Boolean(AbstractBooleanType): """The Nada Boolean type. Represents a constant (literal) boolean.""" @@ -484,11 +485,11 @@ def __invert__(self: "Boolean") -> "Boolean": def is_literal(cls) -> bool: return True - def metatype(self): - return BooleanMetaType() + def type(self): + return BooleanType() -class BooleanMetaType(MetaTypePassthroughMixin): +class BooleanType(TypePassthroughMixin): """Meta type for booleans""" ty = Boolean @@ -503,7 +504,7 @@ class PublicInteger(NumericType): Represents a public unsigned integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -515,11 +516,11 @@ def public_equals( """Implementation of public equality for Public integer types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicIntegerMetaType() + def type(self): + return PublicIntegerType() -class PublicIntegerMetaType(MetaTypePassthroughMixin): +class PublicIntegerType(TypePassthroughMixin): """Meta type for public integers""" ty = PublicInteger @@ -533,7 +534,7 @@ class PublicUnsignedInteger(NumericType): Represents a public integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -545,11 +546,11 @@ def public_equals( """Implementation of public equality for Public unsigned integer types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicUnsignedIntegerMetaType() + def type(self): + return PublicUnsignedIntegerType() -class PublicUnsignedIntegerMetaType(MetaTypePassthroughMixin): +class PublicUnsignedIntegerType(TypePassthroughMixin): """Meta type for public unsigned integers""" ty = PublicUnsignedInteger @@ -558,13 +559,13 @@ class PublicUnsignedIntegerMetaType(MetaTypePassthroughMixin): @dataclass @register_scalar_type(Mode.PUBLIC, BaseType.BOOLEAN) -class PublicBoolean(BooleanType): +class PublicBoolean(AbstractBooleanType): """The Nada Public Boolean type. Represents a public boolean in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -580,11 +581,11 @@ def public_equals( """Implementation of public equality for Public boolean types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicBooleanMetaType() + def type(self): + return PublicBooleanType() -class PublicBooleanMetaType(MetaTypePassthroughMixin): +class PublicBooleanType(TypePassthroughMixin): """Meta type for public booleans""" ty = PublicBoolean @@ -596,7 +597,7 @@ class PublicBooleanMetaType(MetaTypePassthroughMixin): class SecretInteger(NumericType): """The Nada secret integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -635,11 +636,11 @@ def to_public(self: "SecretInteger") -> "PublicInteger": operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicInteger(child=operation) - def metatype(self): - return SecretIntegerMetaType() + def type(self): + return SecretIntegerType() -class SecretIntegerMetaType(MetaTypePassthroughMixin): +class SecretIntegerType(TypePassthroughMixin): """Meta type for secret integers""" ty = SecretInteger @@ -651,7 +652,7 @@ class SecretIntegerMetaType(MetaTypePassthroughMixin): class SecretUnsignedInteger(NumericType): """The Nada Secret Unsigned integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -692,11 +693,11 @@ def to_public( operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicUnsignedInteger(child=operation) - def metatype(self): - return SecretUnsignedIntegerMetaType() + def type(self): + return SecretUnsignedIntegerType() -class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): +class SecretUnsignedIntegerType(TypePassthroughMixin): """Meta type for secret unsigned integers""" ty = SecretUnsignedInteger @@ -705,10 +706,10 @@ class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): @dataclass @register_scalar_type(Mode.SECRET, BaseType.BOOLEAN) -class SecretBoolean(BooleanType): +class SecretBoolean(AbstractBooleanType): """The SecretBoolean Nada MIR type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -728,11 +729,11 @@ def random(cls) -> "SecretBoolean": """Generate a random secret boolean.""" return SecretBoolean(child=Random(source_ref=SourceRef.back_frame())) - def metatype(self): - return SecretBooleanMetaType() + def type(self): + return SecretBooleanType() -class SecretBooleanMetaType(MetaTypePassthroughMixin): +class SecretBooleanType(TypePassthroughMixin): """Meta type for secret booleans""" ty = SecretBoolean @@ -740,41 +741,41 @@ class SecretBooleanMetaType(MetaTypePassthroughMixin): @dataclass -class EcdsaSignature(NadaType): +class EcdsaSignature(DslType): """The EcdsaSignature Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) - def metatype(self): - return EcdsaSignatureMetaType() + def type(self): + return EcdsaSignatureType() -class EcdsaSignatureMetaType(MetaTypePassthroughMixin): +class EcdsaSignatureType(TypePassthroughMixin): """Meta type for EcdsaSignatures""" ty = EcdsaSignature @dataclass -class EcdsaDigestMessage(NadaType): +class EcdsaDigestMessage(DslType): """The EcdsaDigestMessage Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) - def metatype(self): - return EcdsaDigestMessageMetaType() + def type(self): + return EcdsaDigestMessageType() -class EcdsaDigestMessageMetaType(MetaTypePassthroughMixin): +class EcdsaDigestMessageType(TypePassthroughMixin): """Meta type for EcdsaDigestMessages""" ty = EcdsaDigestMessage @dataclass -class EcdsaPrivateKey(NadaType): +class EcdsaPrivateKey(DslType): """The EcdsaPrivateKey Nada MIR type.""" def __init__(self, child: OperationType): @@ -786,11 +787,11 @@ def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature": child=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame()) ) - def metatype(self): - return EcdsaPrivateKeyMetaType() + def type(self): + return EcdsaPrivateKeyType() -class EcdsaPrivateKeyMetaType(MetaTypePassthroughMixin): +class EcdsaPrivateKeyType(TypePassthroughMixin): """Meta type for EcdsaPrivateKeys""" ty = EcdsaPrivateKey diff --git a/nada_dsl/program_io.py b/nada_dsl/program_io.py index 1aaf52f..9848ab6 100644 --- a/nada_dsl/program_io.py +++ b/nada_dsl/program_io.py @@ -15,11 +15,11 @@ ) from nada_dsl.errors import InvalidTypeError from nada_dsl.nada_types import AllTypes, Party -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType from nada_dsl.source_ref import SourceRef -class Input(NadaType): +class Input(DslType): """ Represents an input to the computation. @@ -56,7 +56,7 @@ def store_in_ast(self, ty: object): @dataclass -class Literal(NadaType): +class Literal(DslType): """ Represents a literal value. @@ -103,7 +103,7 @@ class Output: def __init__(self, child, name, party): self.source_ref = SourceRef.back_frame() - if not issubclass(type(child), NadaType): + if not issubclass(type(child), DslType): raise InvalidTypeError( f"{self.source_ref.file}:{self.source_ref.lineno}: Output value " f"{child} of type {type(child)} is not " diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index 284624e..5b664d0 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -129,7 +129,7 @@ def test_duplicated_inputs_checks(): def test_array_type_conversion(input_type, type_name, size): inner_input = create_input(SecretInteger, "name", "party", **{}) collection = create_collection(input_type, inner_input, size, **{}) - converted_input = collection.metatype().to_mir() + converted_input = collection.type().to_mir() assert list(converted_input.keys()) == [type_name] diff --git a/tests/scalar_type_test.py b/tests/scalar_type_test.py index 94b9aa3..480818f 100644 --- a/tests/scalar_type_test.py +++ b/tests/scalar_type_test.py @@ -7,6 +7,8 @@ from nada_dsl import Input, Party from nada_dsl.nada_types import BaseType, Mode from nada_dsl.nada_types.scalar_types import ( + AbstractBoolean, + AbstractBooleanType, Integer, PublicInteger, SecretInteger, @@ -253,7 +255,9 @@ def test_public_equals( @pytest.mark.parametrize("left, right, operation", binary_logic_operations) -def test_logic_operations(left: BooleanType, right: BooleanType, operation): +def test_logic_operations( + left: AbstractBooleanType, right: AbstractBooleanType, operation +): result = operation(left, right) assert result.base_type, BaseType.BOOLEAN assert result.mode.value, max([left.mode.value, right.mode.value])