From 1a09c436c1e909310717f80ef0cf92961aaf236b Mon Sep 17 00:00:00 2001 From: Jmgr Date: Tue, 26 Nov 2024 10:27:56 +0000 Subject: [PATCH] chore: rename NadaType to DslType, rename MetaType suffix to Type --- nada_dsl/nada_types/__init__.py | 13 +- nada_dsl/nada_types/collections.py | 96 ++++---- nada_dsl/nada_types/function.py | 10 +- nada_dsl/nada_types/generics.py | 8 +- nada_dsl/nada_types/scalar_types.py | 107 ++++----- nada_dsl/program_io.py | 8 +- tests/compiler_frontend_test.py | 8 +- tests/scalar_type_test.py | 326 +++++++++++++++++----------- 8 files changed, 324 insertions(+), 252 deletions(-) diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index 9aff870..4bd2290 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -113,13 +113,14 @@ def is_numeric(self) -> bool: return self in (BaseType.INTEGER, BaseType.UNSIGNED_INTEGER) +# TODO: make this abstract? @dataclass -class NadaType: - """Nada type class. +class DslType: + """DSL type class. - This is the parent class of all nada types. + This is the parent class of all DSL types. - In Nada, all the types wrap Operations. For instance, an addition between two integers + In the DSL, all the types wrap Operations. For instance, an addition between two integers is represented like this SecretInteger(child=Addition(...)). In MIR, the representation is based around operations. A MIR operation points to other @@ -145,7 +146,7 @@ def __init__(self, child: OperationType): """ self.child = child if self.child is not None: - self.child.store_in_ast(self.metatype().to_mir()) + self.child.store_in_ast(self.type().to_mir()) def __bool__(self): raise NotImplementedError @@ -161,5 +162,5 @@ def is_literal(cls) -> bool: return False @abstractmethod - def metatype(self): + def type(self): """Returns a meta type for this NadaType.""" diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index 0f4c7f4..06d67b3 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -14,7 +14,7 @@ ReduceASTOperation, UnaryASTOperation, ) -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType # Wildcard import due to non-zero types from nada_dsl.nada_types.scalar_types import * # pylint: disable=W0614:wildcard-import @@ -108,12 +108,12 @@ def store_in_ast(self, ty): ) -class TupleMetaType(MetaType): +class TupleType(DslType): """Marker type for Tuples.""" is_compound = True - def __init__(self, left_type: MetaType, right_type: MetaType): + def __init__(self, left_type: DslType, right_type: DslType): self.left_type = left_type self.right_type = right_type @@ -131,7 +131,7 @@ def to_mir(self): @dataclass -class Tuple(Generic[T, U], NadaType): +class Tuple(Generic[T, U], DslType): """The Tuple type""" left_type: T @@ -144,11 +144,11 @@ def __init__(self, child, left_type: T, right_type: U): super().__init__(self.child) @classmethod - def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": + def new(cls, left_value: DslType, right_value: DslType) -> "Tuple[T, U]": """Constructs a new Tuple.""" return Tuple( - left_type=left_value.metatype(), - right_type=right_value.metatype(), + left_type=left_value.type(), + right_type=right_value.type(), child=TupleNew( child=(left_value, right_value), source_ref=SourceRef.back_frame(), @@ -156,27 +156,27 @@ def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": ) @classmethod - def generic_type(cls, left_type: U, right_type: T) -> TupleMetaType: + def generic_type(cls, left_type: U, right_type: T) -> TupleType: """Returns the generic type for this Tuple""" - return TupleMetaType(left_type=left_type, right_type=right_type) + return TupleType(left_type=left_type, right_type=right_type) - def metatype(self): - """Metatype for Tuple""" - return TupleMetaType(self.left_type, self.right_type) + def type(self): + """Type for Tuple""" + return TupleType(self.left_type, self.right_type) -def _generate_accessor(ty: Any, accessor: Any) -> NadaType: +def _generate_accessor(ty: Any, accessor: Any) -> DslType: if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix raise TypeError("Literals are not supported in accessors") return ty.instantiate(accessor) -class NTupleMetaType(MetaType): +class NTupleType(DslType): """Marker type for NTuples.""" is_compound = True - def __init__(self, types: List[MetaType]): + def __init__(self, types: List[DslType]): self.types = types def instantiate(self, child_or_value): @@ -192,7 +192,7 @@ def to_mir(self): @dataclass -class NTuple(NadaType): +class NTuple(DslType): """The NTuple type""" types: List[Any] @@ -205,7 +205,7 @@ def __init__(self, child, types: List[Any]): @classmethod def new(cls, values: List[Any]) -> "NTuple": """Constructs a new NTuple.""" - types = [value.metatype() for value in values] + types = [value.type() for value in values] return NTuple( types=types, child=NTupleNew( @@ -214,7 +214,7 @@ def new(cls, values: List[Any]) -> "NTuple": ), ) - def __getitem__(self, index: int) -> NadaType: + def __getitem__(self, index: int) -> DslType: if index >= len(self.types): raise IndexError(f"Invalid index {index} for NTuple.") @@ -226,9 +226,9 @@ def __getitem__(self, index: int) -> NadaType: return _generate_accessor(self.types[index], accessor) - def metatype(self): - """Metatype for NTuple""" - return NTupleMetaType(self.types) + def type(self): + """Type for NTuple""" + return NTupleType(self.types) @dataclass @@ -261,12 +261,12 @@ def store_in_ast(self, ty: object): ) -class ObjectMetaType(MetaType): +class ObjectType(DslType): """Marker type for Objects.""" is_compound = True - def __init__(self, types: Dict[str, MetaType]): + def __init__(self, types: Dict[str, DslType]): self.types = types def to_mir(self): @@ -280,7 +280,7 @@ def instantiate(self, child_or_value): @dataclass -class Object(NadaType): +class Object(DslType): """The Object type""" types: Dict[str, Any] @@ -293,7 +293,7 @@ def __init__(self, child, types: Dict[str, Any]): @classmethod def new(cls, values: Dict[str, Any]) -> "Object": """Constructs a new Object.""" - types = {key: value.metatype() for key, value in values.items()} + types = {key: value.type() for key, value in values.items()} return Object( types=types, child=ObjectNew( @@ -302,7 +302,7 @@ def new(cls, values: Dict[str, Any]) -> "Object": ), ) - def __getattr__(self, attr: str) -> NadaType: + def __getattr__(self, attr: str) -> DslType: if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" @@ -316,9 +316,9 @@ def __getattr__(self, attr: str) -> NadaType: return _generate_accessor(self.types[attr], accessor) - def metatype(self): - """Metatype for Object""" - return ObjectMetaType(types=self.types) + def type(self): + """Type for Object""" + return ObjectType(types=self.types) @dataclass @@ -412,7 +412,7 @@ def store_in_ast(self, ty: NadaTypeRepr): ) -class ArrayMetaType(MetaType): +class ArrayType(DslType): """Marker type for arrays.""" is_compound = True @@ -428,7 +428,7 @@ def to_mir(self): # and apply the same logic when the function gets passed to .map() or .reduce() # so we now the size of the array if self.size is None: - raise NotImplementedError("ArrayMetaType.to_mir") + raise NotImplementedError("ArrayType.to_mir") return { "Array": { "inner_type": self.contained_type.to_mir(), @@ -441,7 +441,7 @@ def instantiate(self, child_or_value): @dataclass -class Array(Generic[T], NadaType): +class Array(Generic[T], DslType): """Nada Array type. This is the representation of arrays in Nada MIR. @@ -461,14 +461,14 @@ class Array(Generic[T], NadaType): size: int def __init__(self, child, size: int, contained_type: T = None): - self.contained_type = contained_type or child.metatype() + self.contained_type = contained_type or child.type() self.size = size self.child = ( child if contained_type is not None else getattr(child, "child", None) ) if self.child is not None: - self.child.store_in_ast(self.metatype().to_mir()) + self.child.store_in_ast(self.type().to_mir()) def __iter__(self): raise NotAllowedException( @@ -495,9 +495,9 @@ def map(self: "Array[T]", function) -> "Array": def reduce(self: "Array[T]", function, initial: R) -> R: """The Reduce operation for arrays.""" self.check_not_constant(self.contained_type) - self.check_not_constant(initial.metatype()) + self.check_not_constant(initial.type()) function = create_nada_fn( - function, args_ty=[initial.metatype(), self.contained_type] + function, args_ty=[initial.type(), self.contained_type] ) return function.return_type.instantiate( Reduce( @@ -514,7 +514,7 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]": raise IncompatibleTypesError("Cannot zip arrays of different size") return Array( size=self.size, - contained_type=TupleMetaType( + contained_type=TupleType( left_type=self.contained_type, right_type=other.contained_type, ), @@ -551,7 +551,7 @@ def new(cls, *args) -> "Array[T]": raise TypeError("All arguments must be of the same type") return Array( - contained_type=first_arg.metatype(), + contained_type=first_arg.type(), size=len(args), child=ArrayNew( child=args, @@ -559,9 +559,9 @@ def new(cls, *args) -> "Array[T]": ), ) - def metatype(self): - """Metatype for Array""" - return ArrayMetaType(self.contained_type, self.size) + def type(self): + """Type for Array""" + return ArrayType(self.contained_type, self.size) @dataclass @@ -597,10 +597,10 @@ class NTupleNew: Represents the creation of a new Tuple. """ - child: List[NadaType] + child: List[DslType] source_ref: SourceRef - def __init__(self, child: List[NadaType], source_ref: SourceRef): + def __init__(self, child: List[DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -623,10 +623,10 @@ class ObjectNew: Represents the creation of a new Object. """ - child: Dict[str, NadaType] + child: Dict[str, DslType] source_ref: SourceRef - def __init__(self, child: Dict[str, NadaType], source_ref: SourceRef): + def __init__(self, child: Dict[str, DslType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -644,10 +644,10 @@ def store_in_ast(self, ty: object): def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: """The Unzip operation for Arrays.""" - right_type = ArrayMetaType( + right_type = ArrayType( contained_type=array.contained_type.right_type, size=array.size ) - left_type = ArrayMetaType( + left_type = ArrayType( contained_type=array.contained_type.left_type, size=array.size ) @@ -670,7 +670,7 @@ def __init__(self, child: List[T], source_ref: SourceRef): self.child = child self.source_ref = source_ref - def store_in_ast(self, ty: NadaType): + def store_in_ast(self, ty: DslType): """Store this ArrayNew object in the AST.""" AST_OPERATIONS[self.id] = NewASTOperation( id=self.id, diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index 8b08f10..f9454a4 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -15,7 +15,7 @@ next_operation_id, ) from nada_dsl.nada_types.generics import T, R -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType class NadaFunctionArg(Generic[T]): @@ -65,7 +65,7 @@ def __init__( function: Callable[[T], R], return_type: R, source_ref: SourceRef, - child: NadaType, + child: DslType, ): self.child = child self.id = function_id @@ -98,7 +98,7 @@ class NadaFunctionCall(Generic[R]): """Represents a call to a Nada Function.""" fn: NadaFunction - args: List[NadaType] + args: List[DslType] source_ref: SourceRef def __init__(self, nada_function, args, source_ref): @@ -106,7 +106,7 @@ def __init__(self, nada_function, args, source_ref): self.args = args self.fn = nada_function self.source_ref = source_ref - self.store_in_ast(nada_function.return_type.metatype().to_mir()) + self.store_in_ast(nada_function.return_type.type().to_mir()) def store_in_ast(self, ty): """Store this function call in the AST.""" @@ -147,7 +147,7 @@ def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]: child = fn(*nada_args_type_wrapped) - return_type = child.metatype() + return_type = child.type() return NadaFunction( function_id, function=fn, diff --git a/nada_dsl/nada_types/generics.py b/nada_dsl/nada_types/generics.py index 6923c60..ab53633 100644 --- a/nada_dsl/nada_types/generics.py +++ b/nada_dsl/nada_types/generics.py @@ -2,8 +2,8 @@ from typing import TypeVar -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType -R = TypeVar("R", bound=NadaType) -T = TypeVar("T", bound=NadaType) -U = TypeVar("U", bound=NadaType) +R = TypeVar("R", bound=DslType) +T = TypeVar("T", bound=DslType) +U = TypeVar("U", bound=DslType) diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 50d513a..99d8d60 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -5,10 +5,11 @@ from dataclasses import dataclass from typing import Union, TypeVar from typing_extensions import Self +from nada_dsl.audit.abstract import AbstractBoolean from nada_dsl.operations import * from nada_dsl.program_io import Literal from nada_dsl import SourceRef -from . import NadaType, Mode, BaseType, OperationType +from . import DslType, Mode, BaseType, OperationType # Constant dictionary that stores all the Nada types and is use to # convert from the (mode, base_type) representation to the concrete Nada type @@ -47,7 +48,7 @@ AnyBoolean = Union["Boolean", "PublicBoolean", "SecretBoolean"] -class ScalarType(NadaType): +class ScalarType(DslType): """The Nada Scalar type. This is the super class for all scalar types in Nada. These are: @@ -288,7 +289,7 @@ def public_equals_operation(left: ScalarType, right: ScalarType) -> "PublicBoole ) -class BooleanType(ScalarType): +class AbstractBooleanType(ScalarType): """This abstraction represents all boolean types: - Boolean, PublicBoolean, SecretBoolean It provides common operation implementations for all the boolean types, defined above. @@ -349,7 +350,7 @@ def binary_logical_operation( return SecretBoolean(child=operation) -class MetaType(ABC): +class DslType(ABC): """Abstract meta type""" is_constant = False @@ -365,7 +366,7 @@ def to_mir(self): """Returns a MIR representation of this meta type""" -class MetaTypePassthroughMixin(MetaType): +class TypePassthroughMixin(DslType): """Mixin for meta types""" def instantiate(self, child_or_value): @@ -378,8 +379,8 @@ def to_mir(self): if name.startswith("Public"): name = name[len("Public") :].lstrip() - if name.endswith("MetaType"): - name = name[: -len("MetaType")].rstrip() + if name.endswith("Type"): + name = name[: -len("Type")].rstrip() return name @@ -405,11 +406,11 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True - def metatype(self): - return IntegerMetaType() + def type(self): + return IntegerType() -class IntegerMetaType(MetaTypePassthroughMixin): +class IntegerType(TypePassthroughMixin): """Meta type for integers""" ty = Integer @@ -442,11 +443,11 @@ def __eq__(self, other) -> AnyBoolean: def is_literal(cls) -> bool: return True - def metatype(self): - return UnsignedIntegerMetaType() + def type(self): + return UnsignedIntegerType() -class UnsignedIntegerMetaType(MetaTypePassthroughMixin): +class UnsignedIntegerType(TypePassthroughMixin): """Meta type for unsigned integers""" ty = UnsignedInteger @@ -455,7 +456,7 @@ class UnsignedIntegerMetaType(MetaTypePassthroughMixin): @register_scalar_type(Mode.CONSTANT, BaseType.BOOLEAN) -class Boolean(BooleanType): +class Boolean(AbstractBoolean): """The Nada Boolean type. Represents a constant (literal) boolean.""" @@ -484,11 +485,11 @@ def __invert__(self: "Boolean") -> "Boolean": def is_literal(cls) -> bool: return True - def metatype(self): - return BooleanMetaType() + def type(self): + return BooleanType() -class BooleanMetaType(MetaTypePassthroughMixin): +class BooleanType(TypePassthroughMixin): """Meta type for booleans""" ty = Boolean @@ -503,7 +504,7 @@ class PublicInteger(NumericType): Represents a public unsigned integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -515,11 +516,11 @@ def public_equals( """Implementation of public equality for Public integer types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicIntegerMetaType() + def type(self): + return PublicIntegerType() -class PublicIntegerMetaType(MetaTypePassthroughMixin): +class PublicIntegerType(TypePassthroughMixin): """Meta type for public integers""" ty = PublicInteger @@ -533,7 +534,7 @@ class PublicUnsignedInteger(NumericType): Represents a public integer in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -545,11 +546,11 @@ def public_equals( """Implementation of public equality for Public unsigned integer types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicUnsignedIntegerMetaType() + def type(self): + return PublicUnsignedIntegerType() -class PublicUnsignedIntegerMetaType(MetaTypePassthroughMixin): +class PublicUnsignedIntegerType(TypePassthroughMixin): """Meta type for public unsigned integers""" ty = PublicUnsignedInteger @@ -564,7 +565,7 @@ class PublicBoolean(BooleanType): Represents a public boolean in a program. This is a public variable evaluated at runtime.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.PUBLIC) def __eq__(self, other) -> AnyBoolean: @@ -580,11 +581,11 @@ def public_equals( """Implementation of public equality for Public boolean types.""" return public_equals_operation(self, other) - def metatype(self): - return PublicBooleanMetaType() + def type(self): + return PublicBooleanType() -class PublicBooleanMetaType(MetaTypePassthroughMixin): +class PublicBooleanType(TypePassthroughMixin): """Meta type for public booleans""" ty = PublicBoolean @@ -596,7 +597,7 @@ class PublicBooleanMetaType(MetaTypePassthroughMixin): class SecretInteger(NumericType): """The Nada secret integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -635,11 +636,11 @@ def to_public(self: "SecretInteger") -> "PublicInteger": operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicInteger(child=operation) - def metatype(self): - return SecretIntegerMetaType() + def type(self): + return SecretIntegerType() -class SecretIntegerMetaType(MetaTypePassthroughMixin): +class SecretIntegerType(TypePassthroughMixin): """Meta type for secret integers""" ty = SecretInteger @@ -651,7 +652,7 @@ class SecretIntegerMetaType(MetaTypePassthroughMixin): class SecretUnsignedInteger(NumericType): """The Nada Secret Unsigned integer type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.UNSIGNED_INTEGER, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -692,11 +693,11 @@ def to_public( operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicUnsignedInteger(child=operation) - def metatype(self): - return SecretUnsignedIntegerMetaType() + def type(self): + return SecretUnsignedIntegerType() -class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): +class SecretUnsignedIntegerType(TypePassthroughMixin): """Meta type for secret unsigned integers""" ty = SecretUnsignedInteger @@ -708,7 +709,7 @@ class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): class SecretBoolean(BooleanType): """The SecretBoolean Nada MIR type.""" - def __init__(self, child: NadaType): + def __init__(self, child: DslType): super().__init__(child, BaseType.BOOLEAN, Mode.SECRET) def __eq__(self, other) -> AnyBoolean: @@ -728,11 +729,11 @@ def random(cls) -> "SecretBoolean": """Generate a random secret boolean.""" return SecretBoolean(child=Random(source_ref=SourceRef.back_frame())) - def metatype(self): - return SecretBooleanMetaType() + def type(self): + return SecretBooleanType() -class SecretBooleanMetaType(MetaTypePassthroughMixin): +class SecretBooleanType(TypePassthroughMixin): """Meta type for secret booleans""" ty = SecretBoolean @@ -740,41 +741,41 @@ class SecretBooleanMetaType(MetaTypePassthroughMixin): @dataclass -class EcdsaSignature(NadaType): +class EcdsaSignature(DslType): """The EcdsaSignature Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) - def metatype(self): - return EcdsaSignatureMetaType() + def type(self): + return EcdsaSignatureType() -class EcdsaSignatureMetaType(MetaTypePassthroughMixin): +class EcdsaSignatureType(TypePassthroughMixin): """Meta type for EcdsaSignatures""" ty = EcdsaSignature @dataclass -class EcdsaDigestMessage(NadaType): +class EcdsaDigestMessage(DslType): """The EcdsaDigestMessage Nada MIR type.""" def __init__(self, child: OperationType): super().__init__(child=child) - def metatype(self): - return EcdsaDigestMessageMetaType() + def type(self): + return EcdsaDigestMessageType() -class EcdsaDigestMessageMetaType(MetaTypePassthroughMixin): +class EcdsaDigestMessageType(TypePassthroughMixin): """Meta type for EcdsaDigestMessages""" ty = EcdsaDigestMessage @dataclass -class EcdsaPrivateKey(NadaType): +class EcdsaPrivateKey(DslType): """The EcdsaPrivateKey Nada MIR type.""" def __init__(self, child: OperationType): @@ -786,11 +787,11 @@ def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature": child=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame()) ) - def metatype(self): - return EcdsaPrivateKeyMetaType() + def type(self): + return EcdsaPrivateKeyType() -class EcdsaPrivateKeyMetaType(MetaTypePassthroughMixin): +class EcdsaPrivateKeyType(TypePassthroughMixin): """Meta type for EcdsaPrivateKeys""" ty = EcdsaPrivateKey diff --git a/nada_dsl/program_io.py b/nada_dsl/program_io.py index 308d2a0..e03318d 100644 --- a/nada_dsl/program_io.py +++ b/nada_dsl/program_io.py @@ -15,11 +15,11 @@ ) from nada_dsl.errors import InvalidTypeError from nada_dsl.nada_types import AllTypes, Party -from nada_dsl.nada_types import NadaType +from nada_dsl.nada_types import DslType from nada_dsl.source_ref import SourceRef -class Input(NadaType): +class Input(DslType): """ Represents an input to the computation. @@ -56,7 +56,7 @@ def store_in_ast(self, ty: object): @dataclass -class Literal(NadaType): +class Literal(DslType): """ Represents a literal value. @@ -103,7 +103,7 @@ class Output: def __init__(self, child, name, party): self.source_ref = SourceRef.back_frame() - if not issubclass(type(child), NadaType): + if not issubclass(type(child), DslType): raise InvalidTypeError( f"{self.source_ref.file}:{self.source_ref.lineno}: Output value " f"{child} of type {type(child)} is not " diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index 1a53c8a..5b664d0 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -31,7 +31,11 @@ ) from nada_dsl.nada_types import AllTypes, Party from nada_dsl.nada_types.collections import Array, Tuple, NTuple, Object, unzip -from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, create_nada_fn +from nada_dsl.nada_types.function import ( + NadaFunctionArg, + NadaFunctionCall, + create_nada_fn, +) @pytest.fixture(autouse=True) @@ -125,7 +129,7 @@ def test_duplicated_inputs_checks(): def test_array_type_conversion(input_type, type_name, size): inner_input = create_input(SecretInteger, "name", "party", **{}) collection = create_collection(input_type, inner_input, size, **{}) - converted_input = collection.metatype().to_mir() + converted_input = collection.type().to_mir() assert list(converted_input.keys()) == [type_name] diff --git a/tests/scalar_type_test.py b/tests/scalar_type_test.py index 24abd8a..bf725f4 100644 --- a/tests/scalar_type_test.py +++ b/tests/scalar_type_test.py @@ -6,8 +6,21 @@ from nada_dsl import Input, Party from nada_dsl.nada_types import BaseType, Mode -from nada_dsl.nada_types.scalar_types import Integer, PublicInteger, SecretInteger, Boolean, PublicBoolean, \ - SecretBoolean, UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger, ScalarType, BooleanType +from nada_dsl.nada_types.scalar_types import ( + AbstractBoolean, + Integer, + PublicInteger, + SecretInteger, + Boolean, + PublicBoolean, + SecretBoolean, + UnsignedInteger, + PublicUnsignedInteger, + SecretUnsignedInteger, + ScalarType, + BooleanType, +) + def combine_lists(list1, list2): """This returns all combinations for the items of two lists""" @@ -30,7 +43,7 @@ def combine_lists(list1, list2): booleans = [ Boolean(value=True), PublicBoolean(Input(name="public", party=Party("party"))), - SecretBoolean(Input(name="secret", party=Party("party"))) + SecretBoolean(Input(name="secret", party=Party("party"))), ] # All public boolean values @@ -46,7 +59,7 @@ def combine_lists(list1, list2): integers = [ Integer(value=1), PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="secret", party=Party("party"))) + SecretInteger(Input(name="secret", party=Party("party"))), ] # All public integer values @@ -61,14 +74,14 @@ def combine_lists(list1, list2): # All integer inputs (non literal elements) variable_integers = [ PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="public", party=Party("party"))) + SecretInteger(Input(name="public", party=Party("party"))), ] # All unsigned integer values unsigned_integers = [ UnsignedInteger(value=1), PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="secret", party=Party("party"))) + SecretUnsignedInteger(Input(name="secret", party=Party("party"))), ] # All public unsigned integer values @@ -83,7 +96,7 @@ def combine_lists(list1, list2): # All unsigned integer inputs (non-literal elements) variable_unsigned_integers = [ PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="public", party=Party("party"))) + SecretUnsignedInteger(Input(name="public", party=Party("party"))), ] # Binary arithmetic operations. They are provided as functions to the tests to avoid duplicate code @@ -98,9 +111,11 @@ def combine_lists(list1, list2): # Data set for the binary arithmetic operation tests. It combines all allowed operands with the operations. binary_arithmetic_operations = ( # Integers - combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions) + combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions + ) ) @@ -114,16 +129,16 @@ def test_binary_arithmetic_operations(left: ScalarType, right: ScalarType, opera # Allowed operands for the power operation allowed_pow_operands = ( - # Integers: Only combinations of public integers - combine_lists(public_integers, public_integers) - # UnsignedIntegers: Only combinations of public unsigned integers - + combine_lists(public_unsigned_integers, public_unsigned_integers) + # Integers: Only combinations of public integers + combine_lists(public_integers, public_integers) + # UnsignedIntegers: Only combinations of public unsigned integers + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @pytest.mark.parametrize("left, right", allowed_pow_operands) def test_pow(left: ScalarType, right: ScalarType): - result = left ** right + result = left**right assert result.base_type, left.base_type assert result.base_type, right.base_type assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -137,10 +152,12 @@ def test_pow(left: ScalarType, right: ScalarType): # The shift operations accept public unsigned integers on the right operand only. allowed_shift_operands = ( - # Integers on the left operand - combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) - # UnsignedIntegers on the left operand - + combine_lists(combine_lists(unsigned_integers, public_unsigned_integers), shift_functions) + # Integers on the left operand + combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) + # UnsignedIntegers on the left operand + + combine_lists( + combine_lists(unsigned_integers, public_unsigned_integers), shift_functions + ) ) @@ -157,15 +174,17 @@ def test_shift(left: ScalarType, right: ScalarType, operation): lambda lhs, rhs: lhs < rhs, lambda lhs, rhs: lhs > rhs, lambda lhs, rhs: lhs <= rhs, - lambda lhs, rhs: lhs >= rhs + lambda lhs, rhs: lhs >= rhs, ] # Allowed operands that are accepted by the numeric relational operations. They are combined with the operations. binary_relational_operations = ( - # Integers - combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_relational_functions) + # Integers + combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_relational_functions + ) ) @@ -177,16 +196,13 @@ def test_binary_relational_operations(left: ScalarType, right: ScalarType, opera # Equality operations -equals_functions = [ - lambda lhs, rhs: lhs == rhs, - lambda lhs, rhs: lhs != rhs -] +equals_functions = [lambda lhs, rhs: lhs == rhs, lambda lhs, rhs: lhs != rhs] # Allowed operands that are accepted by the equality operations. They are combined with the operations. equals_operations = ( - combine_lists(itertools.product(integers, repeat=2), equals_functions) - + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) - + combine_lists(itertools.product(booleans, repeat=2), equals_functions) + combine_lists(itertools.product(integers, repeat=2), equals_functions) + + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) + + combine_lists(itertools.product(booleans, repeat=2), equals_functions) ) @@ -199,17 +215,27 @@ def test_equals_operations(left: ScalarType, right: ScalarType, operation): # Allowed operands that are accepted by the public_equals function. Literals are not accepted. public_equals_operands = ( - # Integers - combine_lists(variable_integers, variable_integers) - # UnsignedIntegers - + combine_lists(variable_unsigned_integers, variable_unsigned_integers) + # Integers + combine_lists(variable_integers, variable_integers) + # UnsignedIntegers + + combine_lists(variable_unsigned_integers, variable_unsigned_integers) ) @pytest.mark.parametrize("left, right", public_equals_operands) def test_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): assert isinstance(left.public_equals(right), PublicBoolean) @@ -218,15 +244,17 @@ def test_public_equals( logic_functions = [ lambda lhs, rhs: lhs & rhs, lambda lhs, rhs: lhs | rhs, - lambda lhs, rhs: lhs ^ rhs + lambda lhs, rhs: lhs ^ rhs, ] # Allowed operands that are accepted by the logic operations. They are combined with the operations. -binary_logic_operations = combine_lists(combine_lists(booleans, booleans), logic_functions) +binary_logic_operations = combine_lists( + combine_lists(booleans, booleans), logic_functions +) @pytest.mark.parametrize("left, right, operation", binary_logic_operations) -def test_logic_operations(left: BooleanType, right: BooleanType, operation): +def test_logic_operations(left: AbstractBoolean, right: AbstractBoolean, operation): result = operation(left, right) assert result.base_type, BaseType.BOOLEAN assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -240,10 +268,9 @@ def test_invert_operations(operand): # Allowed operands that are accepted by the probabilistic truncation. -trunc_pr_operands = ( - combine_lists(secret_integers, public_unsigned_integers) - + combine_lists(secret_unsigned_integers, public_unsigned_integers) -) +trunc_pr_operands = combine_lists( + secret_integers, public_unsigned_integers +) + combine_lists(secret_unsigned_integers, public_unsigned_integers) @pytest.mark.parametrize("left, right", trunc_pr_operands) @@ -279,10 +306,14 @@ def test_to_public(operand): # Allow combination of operands that are accepted by if_else function if_else_operands = ( - combine_lists(secret_booleans, combine_lists(integers, integers)) - + combine_lists([public_boolean], combine_lists(integers, integers)) - + combine_lists(secret_booleans, combine_lists(unsigned_integers, unsigned_integers)) - + combine_lists([public_boolean], combine_lists(unsigned_integers, unsigned_integers)) + combine_lists(secret_booleans, combine_lists(integers, integers)) + + combine_lists([public_boolean], combine_lists(integers, integers)) + + combine_lists( + secret_booleans, combine_lists(unsigned_integers, unsigned_integers) + ) + + combine_lists( + [public_boolean], combine_lists(unsigned_integers, unsigned_integers) + ) ) @@ -296,40 +327,57 @@ def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): # List of not allowed operations -not_allowed_binary_operations = \ - ( # Arithmetic operations - combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_arithmetic_functions) - # Relational operations - + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) - + combine_lists(combine_lists(booleans, integers), binary_relational_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_relational_functions) - # Equals operations - + combine_lists(combine_lists(booleans, integers), equals_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) - + combine_lists(combine_lists(integers, booleans), equals_functions) - + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) - + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) - # Logic operations - + combine_lists(combine_lists(booleans, integers), logic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) - + combine_lists(combine_lists(integers, booleans), logic_functions) - + combine_lists(combine_lists(integers, integers), logic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, unsigned_integers), logic_functions) +not_allowed_binary_operations = ( # Arithmetic operations + combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) + + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_arithmetic_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_arithmetic_functions ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_arithmetic_functions + ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_arithmetic_functions + ) + # Relational operations + + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) + + combine_lists(combine_lists(booleans, integers), binary_relational_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_relational_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_relational_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_relational_functions + ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_relational_functions + ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_relational_functions + ) + # Equals operations + + combine_lists(combine_lists(booleans, integers), equals_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) + + combine_lists(combine_lists(integers, booleans), equals_functions) + + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) + + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) + # Logic operations + + combine_lists(combine_lists(booleans, integers), logic_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) + + combine_lists(combine_lists(integers, booleans), logic_functions) + + combine_lists(combine_lists(integers, integers), logic_functions) + + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) + + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) + + combine_lists( + combine_lists(unsigned_integers, unsigned_integers), logic_functions + ) +) @pytest.mark.parametrize("left, right, operation", not_allowed_binary_operations) @@ -341,38 +389,40 @@ def test_not_allowed_binary_operations(left, right, operation): # List of operands that the operation power does not accept. not_allowed_pow = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(secret_integers, integers) - + combine_lists(public_integers, secret_integers) - + combine_lists(integers, unsigned_integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(secret_unsigned_integers, unsigned_integers) - + combine_lists(public_unsigned_integers, secret_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(secret_integers, integers) + + combine_lists(public_integers, secret_integers) + + combine_lists(integers, unsigned_integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(secret_unsigned_integers, unsigned_integers) + + combine_lists(public_unsigned_integers, secret_unsigned_integers) ) @pytest.mark.parametrize("left, right", not_allowed_pow) def test_not_allowed_pow(left, right): with pytest.raises(Exception) as invalid_operation: - left ** right + left**right assert invalid_operation.type == TypeError # List of operands that the shift operation do not accept. not_allowed_shift = ( - combine_lists(combine_lists(booleans, booleans), shift_functions) - + combine_lists(combine_lists(integers, booleans), shift_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) - + combine_lists(combine_lists(booleans, integers), shift_functions) - + combine_lists(combine_lists(integers, integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) - + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions) + combine_lists(combine_lists(booleans, booleans), shift_functions) + + combine_lists(combine_lists(integers, booleans), shift_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) + + combine_lists(combine_lists(booleans, integers), shift_functions) + + combine_lists(combine_lists(integers, integers), shift_functions) + + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) + + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) + + combine_lists( + combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions + ) ) @@ -384,14 +434,25 @@ def test_not_allowed_shift(left, right, operation): # List of operands that the public_equals function does not accept. -not_allowed_public_equals_operands = (combine_lists(variable_integers, variable_unsigned_integers) - + combine_lists(variable_unsigned_integers, variable_integers)) +not_allowed_public_equals_operands = combine_lists( + variable_integers, variable_unsigned_integers +) + combine_lists(variable_unsigned_integers, variable_integers) @pytest.mark.parametrize("left, right", not_allowed_public_equals_operands) def test_not_allowed_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): with pytest.raises(Exception) as invalid_operation: left.public_equals(right) @@ -411,17 +472,17 @@ def test_not_allowed_invert_operations(operand): # List of operands that the probabilistic truncation does not accept. not_allowed_trunc_pr_operands = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(integers, integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(integers, secret_unsigned_integers) - + combine_lists(public_integers, public_unsigned_integers) - + combine_lists(unsigned_integers, secret_unsigned_integers) - + combine_lists(public_unsigned_integers, public_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(integers, integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(integers, secret_unsigned_integers) + + combine_lists(public_integers, public_unsigned_integers) + + combine_lists(unsigned_integers, secret_unsigned_integers) + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @@ -429,7 +490,9 @@ def test_not_allowed_invert_operations(operand): def test_not_allowed_trunc_pr(left, right): with pytest.raises(Exception) as invalid_operation: left.trunc_pr(right) - assert invalid_operation.type == TypeError or invalid_operation.type == AttributeError + assert ( + invalid_operation.type == TypeError or invalid_operation.type == AttributeError + ) # List of types that cannot generate a random value @@ -442,20 +505,23 @@ def test_not_allowed_random(operand): operand.random() assert invalid_operation.type == AttributeError + # List of operands that the function if_else does not accept not_allowed_if_else_operands = ( - # Boolean branches - combine_lists(booleans, combine_lists(booleans, booleans)) - # Branches with different types - + combine_lists(booleans, combine_lists(integers, booleans)) - + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) - + combine_lists(booleans, combine_lists(booleans, integers)) - + combine_lists(booleans, combine_lists(unsigned_integers, integers)) - + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) - + combine_lists(booleans, combine_lists(integers, unsigned_integers)) - # The condition is a literal - + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) - + combine_lists([Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers)) + # Boolean branches + combine_lists(booleans, combine_lists(booleans, booleans)) + # Branches with different types + + combine_lists(booleans, combine_lists(integers, booleans)) + + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) + + combine_lists(booleans, combine_lists(booleans, integers)) + + combine_lists(booleans, combine_lists(unsigned_integers, integers)) + + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) + + combine_lists(booleans, combine_lists(integers, unsigned_integers)) + # The condition is a literal + + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) + + combine_lists( + [Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers) + ) )