Skip to content

Commit

Permalink
fix: if_else then to_public correctly typed
Browse files Browse the repository at this point in the history
  • Loading branch information
navasvarela committed Nov 8, 2024
1 parent ec3406e commit cd5f296
Show file tree
Hide file tree
Showing 4 changed files with 370 additions and 195 deletions.
174 changes: 125 additions & 49 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# pylint:disable=W0401,W0614
"""The Nada Scalar type definitions."""

from __future__ import annotations
from dataclasses import dataclass
from typing import Union
from typing import TypeVar, Union, overload
from typing_extensions import Self
from nada_dsl.operations import *
from nada_dsl.program_io import Literal
from nada_dsl import SourceRef
Expand All @@ -14,6 +15,9 @@
SCALAR_TYPES = {}


BooleanTypes = Union["PublicBoolean", "SecretBoolean", "Boolean"]


class ScalarType(NadaType):
"""The Nada Scalar type.
This is the super class for all scalar types in Nada.
Expand All @@ -33,7 +37,7 @@ def __init__(self, inner: OperationType, base_type: BaseType, mode: Mode):
self.base_type = base_type
self.mode = mode

def __eq__(self, other):
def __eq__(self, other) -> BooleanType:
return equals_operation(
"Equals", "==", self, other, lambda lhs, rhs: lhs == rhs
)
Expand All @@ -43,10 +47,17 @@ def __ne__(self, other):
"NotEquals", "!=", self, other, lambda lhs, rhs: lhs != rhs
)

def to_public(self) -> Self:
"""Convert this scalar type into a public variable.
This is the default implementation, to be overriden by subclasses.
The default behaviour is to return the same type, but for secret types it will invoke the Reveal operation.
"""
return self


def equals_operation(
operation, operator, left: ScalarType, right: ScalarType, f
) -> ScalarType:
) -> BooleanType:
"""This function is an abstraction for the equality operations"""
base_type = left.base_type
if base_type != right.base_type:
Expand All @@ -70,18 +81,23 @@ def equals_operation(
def register_scalar_type(mode: Mode, base_type: BaseType):
"""Decorator used to register a new scalar type in the `SCALAR_TYPES` dictionary."""

def decorator(scalar_type: ScalarType):
SCALAR_TYPES[(mode, base_type)] = scalar_type
scalar_type.mode = mode
scalar_type.base_type = base_type
return scalar_type
def decorator(cls):
SCALAR_TYPES[(mode, base_type)] = cls
cls.mode = mode
cls.base_type = base_type
return cls

return decorator


def new_scalar_type(mode: Mode, base_type: BaseType) -> ScalarType:
def new_scalar_type(mode: Mode, base_type: BaseType, *args) -> ScalarType:
"""Returns the corresponding MIR Nada Type"""
return SCALAR_TYPES.get((mode, base_type))
scalar_type = SCALAR_TYPES.get((mode, base_type))
if scalar_type is None:
raise TypeError(
f"scalar type not found for mode={mode} and base_type={base_type}"
)
return scalar_type(args[0])


class NumericType(ScalarType):
Expand All @@ -94,6 +110,16 @@ class NumericType(ScalarType):

value: int

@classmethod
def new(cls, mode: Mode, base_type: BaseType, *args, **kwargs) -> Self:
"""Returns the corresponding MIR Boolean Type"""
scalar_type = SCALAR_TYPES.get((mode, base_type))
if scalar_type is None:
raise TypeError(
f"scalar type not found for mode={mode} and base_type={base_type}"
)
return scalar_type(args, kwargs)

def __add__(self, other):
return binary_arithmetic_operation(
"Addition", "+", self, other, lambda lhs, rhs: lhs + rhs
Expand Down Expand Up @@ -127,10 +153,10 @@ def __pow__(self, other):
raise TypeError(f"Invalid operation: {self} ** {other}")
mode = Mode(max([self.mode.value, other.mode.value]))
if mode == Mode.CONSTANT:
return new_scalar_type(mode, base_type)(self.value**other.value)
return new_scalar_type(mode, base_type, self.value**other.value)
if mode == Mode.PUBLIC:
inner = Power(left=self, right=other, source_ref=SourceRef.back_frame())
return new_scalar_type(mode, base_type)(inner)
return new_scalar_type(mode, base_type, inner)
raise TypeError(f"Invalid operation: {self} ** {other}")

def __lshift__(self, other):
Expand Down Expand Up @@ -166,14 +192,15 @@ def __ge__(self, other):
def __radd__(self, other):
"""This adds support for builtin `sum` operation for numeric types."""
if isinstance(other, int):
other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)(
other
)
other_type = new_scalar_type(Mode.CONSTANT, self.base_type, other)
return self.__add__(other_type)

return self.__add__(other)


NUMBER = TypeVar("NUMBER", bound=NumericType)


def binary_arithmetic_operation(
operation, operator, left: ScalarType, right: ScalarType, f
) -> ScalarType:
Expand All @@ -186,12 +213,12 @@ def binary_arithmetic_operation(
mode = Mode(max([left.mode.value, right.mode.value]))
match mode:
case Mode.CONSTANT:
return new_scalar_type(mode, base_type)(f(left.value, right.value))
return new_scalar_type(mode, base_type, f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
inner = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(inner)
return new_scalar_type(mode, base_type, inner)


def shift_operation(
Expand All @@ -208,33 +235,15 @@ def shift_operation(
mode = Mode(max([left.mode.value, right_mode.value]))
match mode:
case Mode.CONSTANT:
return new_scalar_type(mode, base_type)(f(left.value, right.value))
return new_scalar_type(mode, base_type, f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
inner = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, base_type)(inner)
return new_scalar_type(mode, base_type, inner)


def binary_relational_operation(
operation, operator, left: ScalarType, right: ScalarType, f
) -> ScalarType:
"""This function is an abstraction for the binary relational operations"""
base_type = left.base_type
if base_type != right.base_type or not base_type.is_numeric():
raise TypeError(f"Invalid operation: {left} {operator} {right}")
mode = Mode(max([left.mode.value, right.mode.value]))
match mode:
case Mode.CONSTANT:
return new_scalar_type(mode, BaseType.BOOLEAN)(f(left.value, right.value))
case Mode.PUBLIC | Mode.SECRET:
inner = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame()
)
return new_scalar_type(mode, BaseType.BOOLEAN)(inner)


def public_equals_operation(left: ScalarType, right: ScalarType) -> ScalarType:
def public_equals_operation(left: ScalarType, right: ScalarType) -> PublicBoolean:
"""This function is an abstraction for the public_equals operation for all types."""
base_type = left.base_type
if base_type != right.base_type:
Expand All @@ -255,6 +264,14 @@ class BooleanType(ScalarType):
It provides common operation implementations for all the boolean types, defined above.
"""

@classmethod
def new(cls, mode: Mode, *args) -> BooleanTypes:
"""Returns the corresponding MIR Boolean Type"""
scalar_type = SCALAR_TYPES.get((mode, BaseType.BOOLEAN))
if scalar_type is None:
raise TypeError(f"scalar type not found for mode={mode}")
return scalar_type(args[0])

def __and__(self, other):
return binary_logical_operation(
"BooleanAnd", "&", self, other, lambda lhs, rhs: lhs & rhs
Expand All @@ -270,7 +287,30 @@ def __xor__(self, other):
"BooleanXor", "^", self, other, lambda lhs, rhs: lhs ^ rhs
)

def if_else(self, arg_0: ScalarType, arg_1: ScalarType) -> ScalarType:
# NOTE: These overloads are just for the type checker.
# They are a way to signal the different return types for if_else
@overload
def if_else(self, arg_0: SecretInteger, arg_1: SecretInteger) -> SecretInteger: ...

@overload
def if_else(self, arg_0: PublicInteger, arg_1: PublicInteger) -> SecretInteger: ...

@overload
def if_else(
self, arg_0: SecretUnsignedInteger, arg_1: SecretUnsignedInteger
) -> SecretUnsignedInteger: ...

@overload
def if_else(
self, arg_0: PublicUnsignedInteger, arg_1: PublicUnsignedInteger
) -> SecretUnsignedInteger: ...

def if_else(self, arg_0: NUMBER, arg_1: NUMBER) -> Union[
PublicInteger,
PublicUnsignedInteger,
SecretInteger,
SecretUnsignedInteger,
]:
"""This function implements the function 'if_else' for every class that extends 'BooleanType'."""
base_type = arg_0.base_type
if (
Expand All @@ -285,7 +325,37 @@ def if_else(self, arg_0: ScalarType, arg_1: ScalarType) -> ScalarType:
)
if mode == Mode.CONSTANT:
mode = Mode.PUBLIC
return new_scalar_type(mode, base_type)(inner)
result_type = NumericType.new(mode, base_type, inner)
if isinstance(
result_type,
Union[
PublicInteger,
PublicUnsignedInteger,
SecretInteger,
SecretUnsignedInteger,
],
):
return result_type

raise TypeError(f"Invalid result type: {result_type}")


def binary_relational_operation(
operation, operator, left: NUMBER, right: NUMBER, f
) -> BooleanTypes:
"""This function is an abstraction for the binary relational operations"""
base_type = left.base_type
if base_type != right.base_type or not base_type.is_numeric():
raise TypeError(f"Invalid operation: {left} {operator} {right}")
mode = Mode(max([left.mode.value, right.mode.value]))
match mode:
case Mode.CONSTANT:
return BooleanType.new(mode, (f(left.value, right.value)))
case Mode.PUBLIC | Mode.SECRET:
inner = globals()[operation](
left=left, right=right, source_ref=SourceRef.back_frame()
)
return BooleanType.new(mode, inner)


def binary_logical_operation(
Expand Down Expand Up @@ -395,6 +465,10 @@ def public_equals(
"""Implementation of public equality for Public integer types."""
return public_equals_operation(self, other)

def to_public(self) -> PublicInteger:
"""Convert this scalar type into a public variable."""
return self


@register_scalar_type(Mode.PUBLIC, BaseType.UNSIGNED_INTEGER)
class PublicUnsignedInteger(NumericType):
Expand All @@ -411,10 +485,14 @@ def __eq__(self, other):

def public_equals(
self, other: Union["PublicUnsignedInteger", "SecretUnsignedInteger"]
) -> "PublicBoolean":
) -> PublicBoolean:
"""Implementation of public equality for Public unsigned integer types."""
return public_equals_operation(self, other)

def to_public(self) -> PublicUnsignedInteger:
"""Convert this scalar type into a public variable."""
return self


@dataclass
@register_scalar_type(Mode.PUBLIC, BaseType.BOOLEAN)
Expand All @@ -430,7 +508,7 @@ def __init__(self, inner: NadaType):
def __eq__(self, other):
return ScalarType.__eq__(self, other)

def __invert__(self: "PublicBoolean") -> "PublicBoolean":
def __invert__(self: PublicBoolean) -> PublicBoolean:
operation = Not(this=self, source_ref=SourceRef.back_frame())
return PublicBoolean(inner=operation)

Expand All @@ -454,7 +532,7 @@ def __eq__(self, other):

def public_equals(
self, other: Union["PublicInteger", "SecretInteger"]
) -> "PublicBoolean":
) -> PublicBoolean:
"""Implementation of public equality for secret integer types."""
return public_equals_operation(self, other)

Expand All @@ -476,11 +554,11 @@ def trunc_pr(
raise TypeError(f"Invalid operation: {self}.trunc_pr({other})")

@classmethod
def random(cls) -> "SecretInteger":
def random(cls) -> SecretInteger:
"""Random operation for Secret integers."""
return SecretInteger(inner=Random(source_ref=SourceRef.back_frame()))

def to_public(self: "SecretInteger") -> "PublicInteger":
def to_public(self: SecretInteger) -> PublicInteger:
"""Convert this secret integer into a public variable."""
operation = Reveal(this=self, source_ref=SourceRef.back_frame())
return PublicInteger(inner=operation)
Expand Down Expand Up @@ -585,7 +663,5 @@ def __init__(self, inner: OperationType):
def ecdsa_sign(self, digest: "EcdsaDigestMessage") -> "EcdsaSignature":
"""Random operation for Secret integers."""
return EcdsaSignature(
inner=EcdsaSign(
left=self, right=digest, source_ref=SourceRef.back_frame()
)
inner=EcdsaSign(left=self, right=digest, source_ref=SourceRef.back_frame())
)
Loading

0 comments on commit cd5f296

Please sign in to comment.