From cd5f296e6f8cb2d16b900dcf9c24d30334a144c5 Mon Sep 17 00:00:00 2001 From: Juan M Salamanca Date: Fri, 8 Nov 2024 16:07:43 +0000 Subject: [PATCH] fix: if_else then to_public correctly typed --- nada_dsl/nada_types/scalar_types.py | 174 ++++++++++---- nada_dsl/scalar_type_test.py | 344 ++++++++++++++++------------ pyproject.toml | 19 +- uv.lock | 28 ++- 4 files changed, 370 insertions(+), 195 deletions(-) diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index c2f861a..e350ea4 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -1,8 +1,9 @@ # pylint:disable=W0401,W0614 """The Nada Scalar type definitions.""" - +from __future__ import annotations from dataclasses import dataclass -from typing import Union +from typing import TypeVar, Union, overload +from typing_extensions import Self from nada_dsl.operations import * from nada_dsl.program_io import Literal from nada_dsl import SourceRef @@ -14,6 +15,9 @@ SCALAR_TYPES = {} +BooleanTypes = Union["PublicBoolean", "SecretBoolean", "Boolean"] + + class ScalarType(NadaType): """The Nada Scalar type. This is the super class for all scalar types in Nada. @@ -33,7 +37,7 @@ def __init__(self, inner: OperationType, base_type: BaseType, mode: Mode): self.base_type = base_type self.mode = mode - def __eq__(self, other): + def __eq__(self, other) -> BooleanType: return equals_operation( "Equals", "==", self, other, lambda lhs, rhs: lhs == rhs ) @@ -43,10 +47,17 @@ def __ne__(self, other): "NotEquals", "!=", self, other, lambda lhs, rhs: lhs != rhs ) + def to_public(self) -> Self: + """Convert this scalar type into a public variable. + This is the default implementation, to be overriden by subclasses. + The default behaviour is to return the same type, but for secret types it will invoke the Reveal operation. + """ + return self + def equals_operation( operation, operator, left: ScalarType, right: ScalarType, f -) -> ScalarType: +) -> BooleanType: """This function is an abstraction for the equality operations""" base_type = left.base_type if base_type != right.base_type: @@ -70,18 +81,23 @@ def equals_operation( def register_scalar_type(mode: Mode, base_type: BaseType): """Decorator used to register a new scalar type in the `SCALAR_TYPES` dictionary.""" - def decorator(scalar_type: ScalarType): - SCALAR_TYPES[(mode, base_type)] = scalar_type - scalar_type.mode = mode - scalar_type.base_type = base_type - return scalar_type + def decorator(cls): + SCALAR_TYPES[(mode, base_type)] = cls + cls.mode = mode + cls.base_type = base_type + return cls return decorator -def new_scalar_type(mode: Mode, base_type: BaseType) -> ScalarType: +def new_scalar_type(mode: Mode, base_type: BaseType, *args) -> ScalarType: """Returns the corresponding MIR Nada Type""" - return SCALAR_TYPES.get((mode, base_type)) + scalar_type = SCALAR_TYPES.get((mode, base_type)) + if scalar_type is None: + raise TypeError( + f"scalar type not found for mode={mode} and base_type={base_type}" + ) + return scalar_type(args[0]) class NumericType(ScalarType): @@ -94,6 +110,16 @@ class NumericType(ScalarType): value: int + @classmethod + def new(cls, mode: Mode, base_type: BaseType, *args, **kwargs) -> Self: + """Returns the corresponding MIR Boolean Type""" + scalar_type = SCALAR_TYPES.get((mode, base_type)) + if scalar_type is None: + raise TypeError( + f"scalar type not found for mode={mode} and base_type={base_type}" + ) + return scalar_type(args, kwargs) + def __add__(self, other): return binary_arithmetic_operation( "Addition", "+", self, other, lambda lhs, rhs: lhs + rhs @@ -127,10 +153,10 @@ def __pow__(self, other): raise TypeError(f"Invalid operation: {self} ** {other}") mode = Mode(max([self.mode.value, other.mode.value])) if mode == Mode.CONSTANT: - return new_scalar_type(mode, base_type)(self.value**other.value) + return new_scalar_type(mode, base_type, self.value**other.value) if mode == Mode.PUBLIC: inner = Power(left=self, right=other, source_ref=SourceRef.back_frame()) - return new_scalar_type(mode, base_type)(inner) + return new_scalar_type(mode, base_type, inner) raise TypeError(f"Invalid operation: {self} ** {other}") def __lshift__(self, other): @@ -166,14 +192,15 @@ def __ge__(self, other): def __radd__(self, other): """This adds support for builtin `sum` operation for numeric types.""" if isinstance(other, int): - other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)( - other - ) + other_type = new_scalar_type(Mode.CONSTANT, self.base_type, other) return self.__add__(other_type) return self.__add__(other) +NUMBER = TypeVar("NUMBER", bound=NumericType) + + def binary_arithmetic_operation( operation, operator, left: ScalarType, right: ScalarType, f ) -> ScalarType: @@ -186,12 +213,12 @@ def binary_arithmetic_operation( mode = Mode(max([left.mode.value, right.mode.value])) match mode: case Mode.CONSTANT: - return new_scalar_type(mode, base_type)(f(left.value, right.value)) + return new_scalar_type(mode, base_type, f(left.value, right.value)) case Mode.PUBLIC | Mode.SECRET: inner = globals()[operation]( left=left, right=right, source_ref=SourceRef.back_frame() ) - return new_scalar_type(mode, base_type)(inner) + return new_scalar_type(mode, base_type, inner) def shift_operation( @@ -208,33 +235,15 @@ def shift_operation( mode = Mode(max([left.mode.value, right_mode.value])) match mode: case Mode.CONSTANT: - return new_scalar_type(mode, base_type)(f(left.value, right.value)) + return new_scalar_type(mode, base_type, f(left.value, right.value)) case Mode.PUBLIC | Mode.SECRET: inner = globals()[operation]( left=left, right=right, source_ref=SourceRef.back_frame() ) - return new_scalar_type(mode, base_type)(inner) + return new_scalar_type(mode, base_type, inner) -def binary_relational_operation( - operation, operator, left: ScalarType, right: ScalarType, f -) -> ScalarType: - """This function is an abstraction for the binary relational operations""" - base_type = left.base_type - if base_type != right.base_type or not base_type.is_numeric(): - raise TypeError(f"Invalid operation: {left} {operator} {right}") - mode = Mode(max([left.mode.value, right.mode.value])) - match mode: - case Mode.CONSTANT: - return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value)) - case Mode.PUBLIC | Mode.SECRET: - inner = globals()[operation]( - left=left, right=right, source_ref=SourceRef.back_frame() - ) - return new_scalar_type(mode, BaseType.BOOLEAN)(inner) - - -def public_equals_operation(left: ScalarType, right: ScalarType) -> ScalarType: +def public_equals_operation(left: ScalarType, right: ScalarType) -> PublicBoolean: """This function is an abstraction for the public_equals operation for all types.""" base_type = left.base_type if base_type != right.base_type: @@ -255,6 +264,14 @@ class BooleanType(ScalarType): It provides common operation implementations for all the boolean types, defined above. """ + @classmethod + def new(cls, mode: Mode, *args) -> BooleanTypes: + """Returns the corresponding MIR Boolean Type""" + scalar_type = SCALAR_TYPES.get((mode, BaseType.BOOLEAN)) + if scalar_type is None: + raise TypeError(f"scalar type not found for mode={mode}") + return scalar_type(args[0]) + def __and__(self, other): return binary_logical_operation( "BooleanAnd", "&", self, other, lambda lhs, rhs: lhs & rhs @@ -270,7 +287,30 @@ def __xor__(self, other): "BooleanXor", "^", self, other, lambda lhs, rhs: lhs ^ rhs ) - def if_else(self, arg_0: ScalarType, arg_1: ScalarType) -> ScalarType: + # NOTE: These overloads are just for the type checker. + # They are a way to signal the different return types for if_else + @overload + def if_else(self, arg_0: SecretInteger, arg_1: SecretInteger) -> SecretInteger: ... + + @overload + def if_else(self, arg_0: PublicInteger, arg_1: PublicInteger) -> SecretInteger: ... + + @overload + def if_else( + self, arg_0: SecretUnsignedInteger, arg_1: SecretUnsignedInteger + ) -> SecretUnsignedInteger: ... + + @overload + def if_else( + self, arg_0: PublicUnsignedInteger, arg_1: PublicUnsignedInteger + ) -> SecretUnsignedInteger: ... + + def if_else(self, arg_0: NUMBER, arg_1: NUMBER) -> Union[ + PublicInteger, + PublicUnsignedInteger, + SecretInteger, + SecretUnsignedInteger, + ]: """This function implements the function 'if_else' for every class that extends 'BooleanType'.""" base_type = arg_0.base_type if ( @@ -285,7 +325,37 @@ def if_else(self, arg_0: ScalarType, arg_1: ScalarType) -> ScalarType: ) if mode == Mode.CONSTANT: mode = Mode.PUBLIC - return new_scalar_type(mode, base_type)(inner) + result_type = NumericType.new(mode, base_type, inner) + if isinstance( + result_type, + Union[ + PublicInteger, + PublicUnsignedInteger, + SecretInteger, + SecretUnsignedInteger, + ], + ): + return result_type + + raise TypeError(f"Invalid result type: {result_type}") + + +def binary_relational_operation( + operation, operator, left: NUMBER, right: NUMBER, f +) -> BooleanTypes: + """This function is an abstraction for the binary relational operations""" + base_type = left.base_type + if base_type != right.base_type or not base_type.is_numeric(): + raise TypeError(f"Invalid operation: {left} {operator} {right}") + mode = Mode(max([left.mode.value, right.mode.value])) + match mode: + case Mode.CONSTANT: + return BooleanType.new(mode, (f(left.value, right.value))) + case Mode.PUBLIC | Mode.SECRET: + inner = globals()[operation]( + left=left, right=right, source_ref=SourceRef.back_frame() + ) + return BooleanType.new(mode, inner) def binary_logical_operation( @@ -395,6 +465,10 @@ def public_equals( """Implementation of public equality for Public integer types.""" return public_equals_operation(self, other) + def to_public(self) -> PublicInteger: + """Convert this scalar type into a public variable.""" + return self + @register_scalar_type(Mode.PUBLIC, BaseType.UNSIGNED_INTEGER) class PublicUnsignedInteger(NumericType): @@ -411,10 +485,14 @@ def __eq__(self, other): def public_equals( self, other: Union["PublicUnsignedInteger", "SecretUnsignedInteger"] - ) -> "PublicBoolean": + ) -> PublicBoolean: """Implementation of public equality for Public unsigned integer types.""" return public_equals_operation(self, other) + def to_public(self) -> PublicUnsignedInteger: + """Convert this scalar type into a public variable.""" + return self + @dataclass @register_scalar_type(Mode.PUBLIC, BaseType.BOOLEAN) @@ -430,7 +508,7 @@ def __init__(self, inner: NadaType): def __eq__(self, other): return ScalarType.__eq__(self, other) - def __invert__(self: "PublicBoolean") -> "PublicBoolean": + def __invert__(self: PublicBoolean) -> PublicBoolean: operation = Not(this=self, source_ref=SourceRef.back_frame()) return PublicBoolean(inner=operation) @@ -454,7 +532,7 @@ def __eq__(self, other): def public_equals( self, other: Union["PublicInteger", "SecretInteger"] - ) -> "PublicBoolean": + ) -> PublicBoolean: """Implementation of public equality for secret integer types.""" return public_equals_operation(self, other) @@ -476,11 +554,11 @@ def trunc_pr( raise TypeError(f"Invalid operation: {self}.trunc_pr({other})") @classmethod - def random(cls) -> "SecretInteger": + def random(cls) -> SecretInteger: """Random operation for Secret integers.""" return SecretInteger(inner=Random(source_ref=SourceRef.back_frame())) - def to_public(self: "SecretInteger") -> "PublicInteger": + def to_public(self: SecretInteger) -> PublicInteger: """Convert this secret integer into a public variable.""" operation = Reveal(this=self, source_ref=SourceRef.back_frame()) return PublicInteger(inner=operation) @@ -585,7 +663,5 @@ def __init__(self, inner: OperationType): def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature": """Random operation for Secret integers.""" return EcdsaSignature( - inner=EcdsaSign( - left=self, right=digest, source_ref=SourceRef.back_frame() - ) + inner=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame()) ) diff --git a/nada_dsl/scalar_type_test.py b/nada_dsl/scalar_type_test.py index 05fa504..aa1e0ea 100644 --- a/nada_dsl/scalar_type_test.py +++ b/nada_dsl/scalar_type_test.py @@ -6,8 +6,20 @@ from nada_dsl import Input, Party from nada_dsl.nada_types import BaseType, Mode -from nada_dsl.nada_types.scalar_types import Integer, PublicInteger, SecretInteger, Boolean, PublicBoolean, \ - SecretBoolean, UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger, ScalarType, BooleanType +from nada_dsl.nada_types.scalar_types import ( + Integer, + NumericType, + PublicInteger, + SecretInteger, + Boolean, + PublicBoolean, + SecretBoolean, + UnsignedInteger, + PublicUnsignedInteger, + SecretUnsignedInteger, + ScalarType, + BooleanType, +) def combine_lists(list1, list2): @@ -31,7 +43,7 @@ def combine_lists(list1, list2): booleans = [ Boolean(value=True), PublicBoolean(Input(name="public", party=Party("party"))), - SecretBoolean(Input(name="secret", party=Party("party"))) + SecretBoolean(Input(name="secret", party=Party("party"))), ] # All public boolean values @@ -47,7 +59,7 @@ def combine_lists(list1, list2): integers = [ Integer(value=1), PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="secret", party=Party("party"))) + SecretInteger(Input(name="secret", party=Party("party"))), ] # All public integer values @@ -62,14 +74,14 @@ def combine_lists(list1, list2): # All integer inputs (non literal elements) variable_integers = [ PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="public", party=Party("party"))) + SecretInteger(Input(name="public", party=Party("party"))), ] # All unsigned integer values unsigned_integers = [ UnsignedInteger(value=1), PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="secret", party=Party("party"))) + SecretUnsignedInteger(Input(name="secret", party=Party("party"))), ] # All public unsigned integer values @@ -84,7 +96,7 @@ def combine_lists(list1, list2): # All unsigned integer inputs (non-literal elements) variable_unsigned_integers = [ PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="public", party=Party("party"))) + SecretUnsignedInteger(Input(name="public", party=Party("party"))), ] # Binary arithmetic operations. They are provided as functions to the tests to avoid duplicate code @@ -99,9 +111,11 @@ def combine_lists(list1, list2): # Data set for the binary arithmetic operation tests. It combines all allowed operands with the operations. binary_arithmetic_operations = ( # Integers - combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions) + combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions + ) ) @@ -115,16 +129,16 @@ def test_binary_arithmetic_operations(left: ScalarType, right: ScalarType, opera # Allowed operands for the power operation allowed_pow_operands = ( - # Integers: Only combinations of public integers - combine_lists(public_integers, public_integers) - # UnsignedIntegers: Only combinations of public unsigned integers - + combine_lists(public_unsigned_integers, public_unsigned_integers) + # Integers: Only combinations of public integers + combine_lists(public_integers, public_integers) + # UnsignedIntegers: Only combinations of public unsigned integers + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @pytest.mark.parametrize("left, right", allowed_pow_operands) def test_pow(left: ScalarType, right: ScalarType): - result = left ** right + result = left**right assert result.base_type, left.base_type assert result.base_type, right.base_type assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -138,10 +152,12 @@ def test_pow(left: ScalarType, right: ScalarType): # The shift operations accept public unsigned integers on the right operand only. allowed_shift_operands = ( - # Integers on the left operand - combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) - # UnsignedIntegers on the left operand - + combine_lists(combine_lists(unsigned_integers, public_unsigned_integers), shift_functions) + # Integers on the left operand + combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) + # UnsignedIntegers on the left operand + + combine_lists( + combine_lists(unsigned_integers, public_unsigned_integers), shift_functions + ) ) @@ -158,15 +174,17 @@ def test_shift(left: ScalarType, right: ScalarType, operation): lambda lhs, rhs: lhs < rhs, lambda lhs, rhs: lhs > rhs, lambda lhs, rhs: lhs <= rhs, - lambda lhs, rhs: lhs >= rhs + lambda lhs, rhs: lhs >= rhs, ] # Allowed operands that are accepted by the numeric relational operations. They are combined with the operations. binary_relational_operations = ( - # Integers - combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_relational_functions) + # Integers + combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_relational_functions + ) ) @@ -178,16 +196,13 @@ def test_binary_relational_operations(left: ScalarType, right: ScalarType, opera # Equality operations -equals_functions = [ - lambda lhs, rhs: lhs == rhs, - lambda lhs, rhs: lhs != rhs -] +equals_functions = [lambda lhs, rhs: lhs == rhs, lambda lhs, rhs: lhs != rhs] # Allowed operands that are accepted by the equality operations. They are combined with the operations. equals_operations = ( - combine_lists(itertools.product(integers, repeat=2), equals_functions) - + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) - + combine_lists(itertools.product(booleans, repeat=2), equals_functions) + combine_lists(itertools.product(integers, repeat=2), equals_functions) + + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) + + combine_lists(itertools.product(booleans, repeat=2), equals_functions) ) @@ -200,17 +215,27 @@ def test_equals_operations(left: ScalarType, right: ScalarType, operation): # Allowed operands that are accepted by the public_equals function. Literals are not accepted. public_equals_operands = ( - # Integers - combine_lists(variable_integers, variable_integers) - # UnsignedIntegers - + combine_lists(variable_unsigned_integers, variable_unsigned_integers) + # Integers + combine_lists(variable_integers, variable_integers) + # UnsignedIntegers + + combine_lists(variable_unsigned_integers, variable_unsigned_integers) ) @pytest.mark.parametrize("left, right", public_equals_operands) def test_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): assert isinstance(left.public_equals(right), PublicBoolean) @@ -219,11 +244,13 @@ def test_public_equals( logic_functions = [ lambda lhs, rhs: lhs & rhs, lambda lhs, rhs: lhs | rhs, - lambda lhs, rhs: lhs ^ rhs + lambda lhs, rhs: lhs ^ rhs, ] # Allowed operands that are accepted by the logic operations. They are combined with the operations. -binary_logic_operations = combine_lists(combine_lists(booleans, booleans), logic_functions) +binary_logic_operations = combine_lists( + combine_lists(booleans, booleans), logic_functions +) @pytest.mark.parametrize("left, right, operation", binary_logic_operations) @@ -241,10 +268,9 @@ def test_invert_operations(operand): # Allowed operands that are accepted by the probabilistic truncation. -trunc_pr_operands = ( - combine_lists(secret_integers, public_unsigned_integers) - + combine_lists(secret_unsigned_integers, public_unsigned_integers) -) +trunc_pr_operands = combine_lists( + secret_integers, public_unsigned_integers +) + combine_lists(secret_unsigned_integers, public_unsigned_integers) @pytest.mark.parametrize("left, right", trunc_pr_operands) @@ -280,10 +306,14 @@ def test_to_public(operand): # Allow combination of operands that are accepted by if_else function if_else_operands = ( - combine_lists(secret_booleans, combine_lists(integers, integers)) - + combine_lists([public_boolean], combine_lists(integers, integers)) - + combine_lists(secret_booleans, combine_lists(unsigned_integers, unsigned_integers)) - + combine_lists([public_boolean], combine_lists(unsigned_integers, unsigned_integers)) + combine_lists(secret_booleans, combine_lists(integers, integers)) + + combine_lists([public_boolean], combine_lists(integers, integers)) + + combine_lists( + secret_booleans, combine_lists(unsigned_integers, unsigned_integers) + ) + + combine_lists( + [public_boolean], combine_lists(unsigned_integers, unsigned_integers) + ) ) @@ -297,40 +327,57 @@ def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): # List of not allowed operations -not_allowed_binary_operations = \ - ( # Arithmetic operations - combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_arithmetic_functions) - # Relational operations - + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) - + combine_lists(combine_lists(booleans, integers), binary_relational_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_relational_functions) - # Equals operations - + combine_lists(combine_lists(booleans, integers), equals_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) - + combine_lists(combine_lists(integers, booleans), equals_functions) - + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) - + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) - # Logic operations - + combine_lists(combine_lists(booleans, integers), logic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) - + combine_lists(combine_lists(integers, booleans), logic_functions) - + combine_lists(combine_lists(integers, integers), logic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, unsigned_integers), logic_functions) +not_allowed_binary_operations = ( # Arithmetic operations + combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) + + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_arithmetic_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_arithmetic_functions + ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_arithmetic_functions + ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_arithmetic_functions + ) + # Relational operations + + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) + + combine_lists(combine_lists(booleans, integers), binary_relational_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_relational_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_relational_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_relational_functions + ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_relational_functions ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_relational_functions + ) + # Equals operations + + combine_lists(combine_lists(booleans, integers), equals_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) + + combine_lists(combine_lists(integers, booleans), equals_functions) + + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) + + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) + # Logic operations + + combine_lists(combine_lists(booleans, integers), logic_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) + + combine_lists(combine_lists(integers, booleans), logic_functions) + + combine_lists(combine_lists(integers, integers), logic_functions) + + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) + + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) + + combine_lists( + combine_lists(unsigned_integers, unsigned_integers), logic_functions + ) +) @pytest.mark.parametrize("left, right, operation", not_allowed_binary_operations) @@ -342,38 +389,40 @@ def test_not_allowed_binary_operations(left, right, operation): # List of operands that the operation power does not accept. not_allowed_pow = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(secret_integers, integers) - + combine_lists(public_integers, secret_integers) - + combine_lists(integers, unsigned_integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(secret_unsigned_integers, unsigned_integers) - + combine_lists(public_unsigned_integers, secret_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(secret_integers, integers) + + combine_lists(public_integers, secret_integers) + + combine_lists(integers, unsigned_integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(secret_unsigned_integers, unsigned_integers) + + combine_lists(public_unsigned_integers, secret_unsigned_integers) ) @pytest.mark.parametrize("left, right", not_allowed_pow) def test_not_allowed_pow(left, right): with pytest.raises(Exception) as invalid_operation: - left ** right + left**right assert invalid_operation.type == TypeError # List of operands that the shift operation do not accept. not_allowed_shift = ( - combine_lists(combine_lists(booleans, booleans), shift_functions) - + combine_lists(combine_lists(integers, booleans), shift_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) - + combine_lists(combine_lists(booleans, integers), shift_functions) - + combine_lists(combine_lists(integers, integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) - + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions) + combine_lists(combine_lists(booleans, booleans), shift_functions) + + combine_lists(combine_lists(integers, booleans), shift_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) + + combine_lists(combine_lists(booleans, integers), shift_functions) + + combine_lists(combine_lists(integers, integers), shift_functions) + + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) + + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) + + combine_lists( + combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions + ) ) @@ -385,14 +434,25 @@ def test_not_allowed_shift(left, right, operation): # List of operands that the public_equals function does not accept. -not_allowed_public_equals_operands = (combine_lists(variable_integers, variable_unsigned_integers) - + combine_lists(variable_unsigned_integers, variable_integers)) +not_allowed_public_equals_operands = combine_lists( + variable_integers, variable_unsigned_integers +) + combine_lists(variable_unsigned_integers, variable_integers) @pytest.mark.parametrize("left, right", not_allowed_public_equals_operands) def test_not_allowed_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): with pytest.raises(Exception) as invalid_operation: left.public_equals(right) @@ -412,17 +472,17 @@ def test_not_allowed_invert_operations(operand): # List of operands that the probabilistic truncation does not accept. not_allowed_trunc_pr_operands = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(integers, integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(integers, secret_unsigned_integers) - + combine_lists(public_integers, public_unsigned_integers) - + combine_lists(unsigned_integers, secret_unsigned_integers) - + combine_lists(public_unsigned_integers, public_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(integers, integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(integers, secret_unsigned_integers) + + combine_lists(public_integers, public_unsigned_integers) + + combine_lists(unsigned_integers, secret_unsigned_integers) + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @@ -430,7 +490,9 @@ def test_not_allowed_invert_operations(operand): def test_not_allowed_trunc_pr(left, right): with pytest.raises(Exception) as invalid_operation: left.trunc_pr(right) - assert invalid_operation.type == TypeError or invalid_operation.type == AttributeError + assert ( + invalid_operation.type == TypeError or invalid_operation.type == AttributeError + ) # List of types that cannot generate a random value @@ -444,36 +506,36 @@ def test_not_allowed_random(operand): assert invalid_operation.type == AttributeError -# List of types that cannot invoke the function to_public() -to_public_operands = public_booleans + public_integers + public_unsigned_integers - - -@pytest.mark.parametrize("operand", to_public_operands) -def test_not_to_public(operand): - with pytest.raises(Exception) as invalid_operation: - operand.to_public() - assert invalid_operation.type == AttributeError - - # List of operands that the function if_else does not accept not_allowed_if_else_operands = ( - # Boolean branches - combine_lists(booleans, combine_lists(booleans, booleans)) - # Branches with different types - + combine_lists(booleans, combine_lists(integers, booleans)) - + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) - + combine_lists(booleans, combine_lists(booleans, integers)) - + combine_lists(booleans, combine_lists(unsigned_integers, integers)) - + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) - + combine_lists(booleans, combine_lists(integers, unsigned_integers)) - # The condition is a literal - + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) - + combine_lists([Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers)) + # Boolean branches + combine_lists(booleans, combine_lists(booleans, booleans)) + # Branches with different types + + combine_lists(booleans, combine_lists(integers, booleans)) + + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) + + combine_lists(booleans, combine_lists(booleans, integers)) + + combine_lists(booleans, combine_lists(unsigned_integers, integers)) + + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) + + combine_lists(booleans, combine_lists(integers, unsigned_integers)) + # The condition is a literal + + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) + + combine_lists( + [Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers) + ) ) @pytest.mark.parametrize("condition, left, right", not_allowed_if_else_operands) -def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): +def test_if_else(condition: BooleanType, left, right): with pytest.raises(Exception) as invalid_operation: condition.if_else(left, right) assert invalid_operation.type == TypeError + + +@pytest.mark.skip( + reason="This is just a test to verify that the typechecker works okay" +) +def test2_if_else(int1: SecretInteger, int2: SecretInteger) -> PublicInteger: + condition = int1 > int2 + if_else_result = condition.if_else(int1, int2) + return if_else_result.to_public() diff --git a/pyproject.toml b/pyproject.toml index d61485b..945f39f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,13 @@ version = "0.7.0" description = "Nillion Nada DSL to create Nillion MPC programs." requires-python = ">=3.10" readme = "README.pyproject.md" -dependencies = ["asttokens~=2.4", "richreports~=0.2", "parsial~=0.1", "sortedcontainers~=2.4"] +dependencies = [ + "asttokens~=2.4", + "richreports~=0.2", + "parsial~=0.1", + "sortedcontainers~=2.4", + "typing_extensions~=4.12", +] classifiers = ["License :: OSI Approved :: Apache Software License"] license = { file = "LICENSE" } @@ -32,8 +38,10 @@ dev-dependencies = [ "pytest-cov>=4,<6", "pylint>=2.17,<3.4", "nada-mir-proto[dev]", - "tomli", - "requests", + "tomli~=2.0", + "requests~=2.32", + "pyright~=1.1", + "typing_extensions~=4.12", ] [tool.uv.sources] @@ -41,3 +49,8 @@ nada-mir-proto = { workspace = true } [tool.uv.workspace] members = ["nada_mir"] + +[tool.pyright] +venvPath = "." +venv = ".venv" +exclude = ["**/__pycache__", "dist/", "test-programs", "**/build", ".venv/", "nada_dsl/*_test.py"] diff --git a/uv.lock b/uv.lock index 0692ce9..c42c444 100644 --- a/uv.lock +++ b/uv.lock @@ -646,6 +646,7 @@ test = [ dev = [ { name = "nada-mir-proto", extra = ["dev"] }, { name = "pylint" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "requests" }, @@ -673,13 +674,14 @@ requires-dist = [ dev = [ { name = "nada-mir-proto", extras = ["dev"], editable = "nada_mir" }, { name = "pylint", specifier = ">=2.17,<3.4" }, + { name = "pyright", specifier = "~=1.1" }, { name = "pytest", specifier = ">=7.4,<9.0" }, { name = "pytest-cov", specifier = ">=4,<6" }, - { name = "requests" }, + { name = "requests", specifier = "~=2.32" }, { name = "sphinx", specifier = ">=5,<9" }, { name = "sphinx-rtd-theme", specifier = ">=1.0,<3.1" }, { name = "toml", specifier = "~=0.10.2" }, - { name = "tomli" }, + { name = "tomli", specifier = "~=2.0" }, ] [[package]] @@ -703,6 +705,15 @@ requires-dist = [ { name = "grpcio-tools", specifier = "==1.62.3" }, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + [[package]] name = "packaging" version = "24.1" @@ -790,6 +801,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/11/4a3f814eee14593f3cfcf7046bc765bf1646d5c88132c08c45310fc7d85f/pylint-3.3.1-py3-none-any.whl", hash = "sha256:2f846a466dd023513240bc140ad2dd73bfc080a5d85a710afdb728c420a5a2b9", size = 521768 }, ] +[[package]] +name = "pyright" +version = "1.1.388" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/83/e9867538a794638d2d20ac3ab3106a31aca1d9cfea530c9b2921809dae03/pyright-1.1.388.tar.gz", hash = "sha256:0166d19b716b77fd2d9055de29f71d844874dbc6b9d3472ccd22df91db3dfa34", size = 21939 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/57/7fb00363b7f267a398c5bdf4f55f3e64f7c2076b2e7d2901b3373d52b6ff/pyright-1.1.388-py3-none-any.whl", hash = "sha256:c7068e9f2c23539c6ac35fc9efac6c6c1b9aa5a0ce97a9a8a6cf0090d7cbf84c", size = 18579 }, +] + [[package]] name = "pytest" version = "8.3.3"