diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 333e3ff..06b44da 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,7 +28,8 @@ jobs: version: "0.4.24" - name: Install dependencies run: uv sync --all-extras --dev - + - name: Format check + run: make format-check - name: Lint with pylint run: | uv run pylint nada_dsl/ diff --git a/Makefile b/Makefile index bb49512..3fd1614 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,16 @@ release: lint: uv run pylint nada_dsl/ +format: + uv run ruff format + +format-check: + uv run ruff format --check + +# TODO fix all check errors +ruff-check: + uv run ruff check + test-dependencies: pip install .'[test]' diff --git a/nada_dsl/ast_util.py b/nada_dsl/ast_util.py index c3d8b33..a25a2fa 100644 --- a/nada_dsl/ast_util.py +++ b/nada_dsl/ast_util.py @@ -1,4 +1,4 @@ -""" AST utilities.""" +"""AST utilities.""" from abc import ABC from dataclasses import dataclass @@ -94,7 +94,6 @@ def child_operations(self): return [self.child] def to_mir(self): - return { self.name: { "id": self.id, diff --git a/nada_dsl/audit/__init__.py b/nada_dsl/audit/__init__.py index 98738ac..ce465fe 100644 --- a/nada_dsl/audit/__init__.py +++ b/nada_dsl/audit/__init__.py @@ -1,20 +1,22 @@ """Export classes and functions for Nada DSL auditing component.""" + import argparse from nada_dsl.audit.abstract import * from nada_dsl.audit.report import html from nada_dsl.audit.strict import strict + def _main(): parser = argparse.ArgumentParser() - parser.add_argument('path', nargs=1, help='Nada DSL source file path') - parser.add_argument('--strict', action='store_true', required=True) + parser.add_argument("path", nargs=1, help="Nada DSL source file path") + parser.add_argument("--strict", action="store_true", required=True) args = parser.parse_args() path = args.path[0] - with open(path, 'r', encoding='UTF-8') as file: + with open(path, "r", encoding="UTF-8") as file: source = file.read() report = strict(source) - with open(path[:-2] + 'html', 'w', encoding='UTF-8') as file: + with open(path[:-2] + "html", "w", encoding="UTF-8") as file: file.write(html(report)) diff --git a/nada_dsl/audit/__main__.py b/nada_dsl/audit/__main__.py index f32948a..3cc0d56 100644 --- a/nada_dsl/audit/__main__.py +++ b/nada_dsl/audit/__main__.py @@ -1,3 +1,5 @@ """Execute the command line interface entry point.""" + from nada_dsl.audit import _main + _main() diff --git a/nada_dsl/audit/abstract.py b/nada_dsl/audit/abstract.py index 44e912c..8b7e1ee 100644 --- a/nada_dsl/audit/abstract.py +++ b/nada_dsl/audit/abstract.py @@ -2,6 +2,7 @@ Abstract interpreter and type definitions for the Nada DSL auditing component. """ + # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-lines # pylint: disable=too-few-public-methods @@ -9,17 +10,19 @@ from typing import Union, Tuple, Sequence import ast + class Metaclass(type): """ Metaclass for the :obj:`Constant`, :obj:`Public`, and :obj:`Secret` classes, enabling comparison of derived classes and conversions between them. """ + def shape(cls: type, cls_: type = None) -> type: """ Return the shape of a class that is an instance of this metaclass. - + >>> PublicInteger.shape().__name__ 'Public' >>> AbstractInteger.shape(Constant).__name__ @@ -73,6 +76,7 @@ def __le__(cls: type, other: type) -> bool: return True + class Constant(metaclass=Metaclass): """ Class from which classes for constants are derived. @@ -89,6 +93,7 @@ class Constant(metaclass=Metaclass): 'Constant' """ + class Public(metaclass=Metaclass): """ Class from which classes for public values are derived. @@ -103,6 +108,7 @@ class Public(metaclass=Metaclass): 'Public' """ + class Secret(metaclass=Metaclass): """ Class from which classes for constants are derived. @@ -117,11 +123,12 @@ class Secret(metaclass=Metaclass): 'Secret' """ + class Abstract(metaclass=Metaclass): """ Base class for abstract interpreter values. All more specific abstract value types are derived from this class. - + The attributes of this class are also used as global aggregators of the signature components (parties, inputs, and outputs) during abstract execution. @@ -135,6 +142,7 @@ class Abstract(metaclass=Metaclass): ... ] [['Party'], ['Input'], ['Output']] """ + # pylint: disable=missing-function-docstring parties = None inputs = None @@ -148,21 +156,14 @@ def initialize(context=None): Abstract.inputs = [] Abstract.outputs = [] Abstract.context = context if context is not None else {} - Abstract.analysis = { - 'add': 0, - 'mul': 0, - 'cmp': 0, - 'eq': 0, - 'ne': 0, - 'ife': 0 - } + Abstract.analysis = {"add": 0, "mul": 0, "cmp": 0, "eq": 0, "ne": 0, "ife": 0} @staticmethod def party(party: Party): Abstract.parties.append(party) @staticmethod - def input(input: Input): # pylint: disable=redefined-builtin + def input(input: Input): # pylint: disable=redefined-builtin Abstract.inputs.append(input) @staticmethod @@ -178,6 +179,7 @@ def __init__(self: Abstract, cls: type = None): if cls is not None: self.__class__ = cls + class Party(Abstract): """ Abstract interpreter values corresponding to parties. @@ -193,16 +195,18 @@ class Party(Abstract): ... TypeError: name parameter must be a string """ + def __init__(self: Party, name: str): super().__init__() type(self).party(self) if not isinstance(name, str): - raise TypeError('name parameter must be a string') + raise TypeError("name parameter must be a string") self.name = name + class Input(Abstract): """ Abstract interpreter values corresponding to inputs. @@ -222,16 +226,17 @@ class Input(Abstract): ... TypeError: party parameter must be a Party object """ + def __init__(self: Input, name: str, party: Party): super().__init__() type(self).input(self) if not isinstance(name, str): - raise TypeError('name parameter must be a string') + raise TypeError("name parameter must be a string") if not isinstance(party, Party): - raise TypeError('party parameter must be a Party object') + raise TypeError("party parameter must be a Party object") self.name = name self.party = party @@ -242,12 +247,13 @@ def _value(self) -> int: """ return self.context.get(self.name, None) + class Output(Abstract): """ Abstract interpreter values corresponding to outputs. >>> party = Party("party") - >>> input = Input("input", party) + >>> input = Input("input", party) >>> isinstance(Output(PublicInteger(input), "output", party), Output) True @@ -267,24 +273,25 @@ class Output(Abstract): ... TypeError: party parameter must be a Party object """ + def __init__( - self: Output, - value: Union[PublicInteger, SecretInteger], - name: str, - party: Party - ): + self: Output, + value: Union[PublicInteger, SecretInteger], + name: str, + party: Party, + ): super().__init__() type(self).output(self) if not isinstance(value, (PublicInteger, SecretInteger)): - raise TypeError('output value must be a PublicInteger or a SecretInteger') + raise TypeError("output value must be a PublicInteger or a SecretInteger") if not isinstance(name, str): - raise TypeError('name parameter must be a string') + raise TypeError("name parameter must be a string") if not isinstance(party, Party): - raise TypeError('party parameter must be a Party object') + raise TypeError("party parameter must be a Party object") self.value = value self.name = name @@ -295,29 +302,30 @@ def __init__( self.final = (type(self).parties, type(self).inputs, type(self).outputs) self.analysis = Abstract.analysis + class AbstractInteger(Abstract): """ Abstract interpreter values corresponding to all integers types. This class is only used by derived classes and is not exported. """ + def __init__( - self: Output, - input: Input = None, # pylint: disable=redefined-builtin - value: int = None - ): + self: Output, + input: Input = None, # pylint: disable=redefined-builtin + value: int = None, + ): super().__init__() self.input = input self.value = self.input._value() if input is not None else value if input is not None: - if not hasattr(input, '_type'): - setattr(input, '_type', None) + if not hasattr(input, "_type"): + setattr(input, "_type", None) input._type = type(self) def __add__( - self: AbstractInteger, - other: Union[int, AbstractInteger] - ) -> AbstractInteger: + self: AbstractInteger, other: Union[int, AbstractInteger] + ) -> AbstractInteger: """ Addition of abstract values that are instances of integer classes. The table below presents the output type for each combination @@ -374,17 +382,19 @@ def __add__( ... TypeError: expecting Integer, PublicInteger, or SecretInteger """ - if isinstance(other, int) and other == 0: # Base case for compatibility with :obj:`sum`. + if ( + isinstance(other, int) and other == 0 + ): # Base case for compatibility with :obj:`sum`. result = Abstract(type(self)) result.value = self.value return result if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") result = Abstract(max(type(self), type(other))) - Abstract.analysis['add'] += 1 + Abstract.analysis["add"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -393,9 +403,8 @@ def __add__( return result def __radd__( - self: AbstractInteger, - other: Union[int, AbstractInteger] - ) -> AbstractInteger: + self: AbstractInteger, other: Union[int, AbstractInteger] + ) -> AbstractInteger: """ Addition for cases in which the left-hand argument is not an instance of a class derived from this class. @@ -407,7 +416,7 @@ def __sub__(self: AbstractInteger, other: AbstractInteger) -> AbstractInteger: Subtraction of abstract values that are instances of integer classes. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") result = Abstract(max(type(self), type(other))) @@ -502,11 +511,11 @@ def __mul__(self: AbstractInteger, other: AbstractInteger) -> AbstractInteger: TypeError: expecting Integer, PublicInteger, or SecretInteger """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") result = Abstract(max(type(self), type(other))) - Abstract.analysis['mul'] += 1 + Abstract.analysis["mul"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -571,11 +580,11 @@ def __lt__(self: AbstractInteger, other: AbstractInteger) -> AbstractInteger: TypeError: expecting Integer, PublicInteger, or SecretInteger """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['cmp'] += 1 + Abstract.analysis["cmp"] += 1 result.value = None if self.value is not None and other.value is not None: result.value = self.value < other.value @@ -588,12 +597,12 @@ def __le__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: classes. See :obj:`AbstractInteger.__lt__` for details and examples. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['cmp'] += 1 + Abstract.analysis["cmp"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -607,12 +616,12 @@ def __gt__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: classes. See :obj:`AbstractInteger.__lt__` for details and examples. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['cmp'] += 1 + Abstract.analysis["cmp"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -626,12 +635,12 @@ def __ge__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: classes. See :obj:`AbstractInteger.__lt__` for details and examples. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['cmp'] += 1 + Abstract.analysis["cmp"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -645,12 +654,12 @@ def __eq__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: classes. See :obj:`AbstractInteger.__lt__` for details and examples. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['eq'] += 1 + Abstract.analysis["eq"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -664,12 +673,12 @@ def __ne__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: classes. See :obj:`AbstractInteger.__lt__` for details and examples. """ if not isinstance(other, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max(type(self), type(other)).shape() result = (AbstractBoolean.shape(shape))() - Abstract.analysis['ne'] += 1 + Abstract.analysis["ne"] += 1 result.value = None if self.value is not None and other.value is not None: @@ -677,6 +686,7 @@ def __ne__(self: AbstractInteger, other: AbstractInteger) -> AbstractBoolean: return result + class Integer(AbstractInteger, Constant): """ Abstract values corresponding to constant integers. @@ -750,6 +760,7 @@ class Integer(AbstractInteger, Constant): TypeError: expecting Integer, PublicInteger, or SecretInteger """ + class PublicInteger(AbstractInteger, Public): """ Abstract values corresponding to public integers. @@ -821,6 +832,7 @@ class PublicInteger(AbstractInteger, Public): TypeError: expecting Integer, PublicInteger, or SecretInteger """ + class SecretInteger(AbstractInteger, Secret): """ Abstract interpreter values corresponding to secret integers. @@ -875,30 +887,30 @@ class SecretInteger(AbstractInteger, Secret): TypeError: expecting Integer, PublicInteger, or SecretInteger """ + class AbstractBoolean(Abstract): """ Abstract interpreter values corresponding to all boolean types. This class is only used by derived classes and is not exported. """ + def __init__( - self: Output, - input: Input = None, # pylint: disable=redefined-builtin - value: int = None - ): + self: Output, + input: Input = None, # pylint: disable=redefined-builtin + value: int = None, + ): super().__init__() self.input = input self.value = self.input._value() if input is not None else value if input is not None: - if not hasattr(input, '_type'): - setattr(input, '_type', None) + if not hasattr(input, "_type"): + setattr(input, "_type", None) input._type = type(self) def if_else( - self: AbstractBoolean, - true: AbstractInteger, - false: AbstractInteger - ) -> AbstractInteger: + self: AbstractBoolean, true: AbstractInteger, false: AbstractInteger + ) -> AbstractInteger: """ Ternary (*i.e.*, conditional) operator. The table below presents the output type for each combination of argument types. @@ -940,37 +952,45 @@ def if_else( TypeError: expecting Integer, PublicInteger, or SecretInteger """ if not isinstance(true, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") if not isinstance(false, (Integer, PublicInteger, SecretInteger)): - raise TypeError('expecting Integer, PublicInteger, or SecretInteger') + raise TypeError("expecting Integer, PublicInteger, or SecretInteger") shape = max([type(self), type(true), type(false)]).shape() result = (AbstractInteger.shape(shape))() - Abstract.analysis['ife'] += 1 + Abstract.analysis["ife"] += 1 result.value = None - if self.value is not None and true.value is not None and false.value is not None: + if ( + self.value is not None + and true.value is not None + and false.value is not None + ): result.value = true.value if self.value else false.value return result + class Boolean(AbstractBoolean, Constant): """ Abstract values corresponding to constant boolean values. """ + class PublicBoolean(AbstractBoolean, Public): """ Abstract values corresponding to public boolean values. """ + class SecretBoolean(AbstractBoolean, Secret): """ Abstract values corresponding to secret boolean values. """ + def signature(source: str) -> Tuple[list[Party], list[Input], list[Output]]: """ Return the signature of the supplied Nada program (represented as a @@ -1029,24 +1049,26 @@ def signature(source: str) -> Tuple[list[Party], list[Input], list[Output]]: root = ast.parse(source) if ( - len(root.body) == 0 or - not isinstance(root.body[0], ast.ImportFrom) or - len(root.body[0].names) != 1 or - root.body[0].names[0].name != '*' or - root.body[0].module != 'nada_dsl' + len(root.body) == 0 + or not isinstance(root.body[0], ast.ImportFrom) + or len(root.body[0].names) != 1 + or root.body[0].names[0].name != "*" + or root.body[0].module != "nada_dsl" ): - raise ValueError('first statement must be: from nada_dsl import *') + raise ValueError("first statement must be: from nada_dsl import *") # Adjust the import statement and add a statement that resets the static # class attributes being used for aggregation. - root.body[0].module = 'nada_dsl.audit' - #root.body.append(ast.Expr(ast.Call(ast.Name('nada_main', ast.Load()), [], []))) + root.body[0].module = "nada_dsl.audit" + # root.body.append(ast.Expr(ast.Call(ast.Name('nada_main', ast.Load()), [], []))) root.body.append( ast.Expr( ast.Call( - ast.Attribute(ast.Name('Abstract', ast.Load()), 'initialize', ast.Load()), + ast.Attribute( + ast.Name("Abstract", ast.Load()), "initialize", ast.Load() + ), + [], [], - [] ) ) ) @@ -1054,18 +1076,18 @@ def signature(source: str) -> Tuple[list[Party], list[Input], list[Output]]: # Execute the program (introducing the main function into the context). context = {} - exec(compile(root, '', 'exec'), context) # pylint: disable=exec-used - if 'nada_main' not in context: - raise ValueError('nada_main must be defined') + exec(compile(root, "", "exec"), context) # pylint: disable=exec-used + if "nada_main" not in context: + raise ValueError("nada_main must be defined") # Perform abstract execution of the main function and return the signature of # the result. - outputs = context['nada_main']() + outputs = context["nada_main"]() if ( - isinstance(outputs, Sequence) and - len(outputs) > 0 and - all(isinstance(output, Output) for output in outputs) + isinstance(outputs, Sequence) + and len(outputs) > 0 + and all(isinstance(output, Output) for output in outputs) ): return Abstract.signature() - raise ValueError('nada_main must return a sequence of outputs') + raise ValueError("nada_main must return a sequence of outputs") diff --git a/nada_dsl/audit/common.py b/nada_dsl/audit/common.py index 722cb76..c555ef9 100644 --- a/nada_dsl/audit/common.py +++ b/nada_dsl/audit/common.py @@ -2,44 +2,52 @@ Common classes and functions for the Nada DSL auditing component (used across different static analysis submodules). """ + # pylint: disable=wildcard-import,invalid-name # pylint: disable=too-few-public-methods from __future__ import annotations import ast + class Rule(Exception): """ Base class for violations of rules defined by static analysis submodules. """ + class RuleInAncestor: """ Rule attribute value that indicates that the Nada DSL rule attribute of an :obj:`ast` node is determined by an ancestor node's Nada DSL rule attribute. """ + class SyntaxRestriction(Rule): """ Base class for violations of syntax restrictions defined by static analysis submodules. """ + class TypeErrorRoot(TypeError): """ Class for type errors that are not caused by other type errors. """ + class TypeInParent: """ Type attribute value that indicates that the Nada DSL type attribute of an :obj:`ast` node is determined by the parent node's Nada DSL type attribute. """ + class Feedback: """ Feedback aggregator (used throughout the recursive static analysis algorithms in order to collect all created exceptions). """ + def __init__(self: Feedback): self.exceptions = [] @@ -47,6 +55,7 @@ def __call__(self: Feedback, exception: Exception): self.exceptions.append(exception) return exception + def typeerror_demote(t): """ Demote a direct-cause type error to a generic (possibly indirect) type error. @@ -56,13 +65,14 @@ def typeerror_demote(t): return t + def audits(node, key, value=None, default=None, delete=False): """ Set, update, or delete an :obj:`ast` node's static analysis attribute. """ # pylint: disable=protected-access - if not hasattr(node, '_audits'): - setattr(node, '_audits', {}) + if not hasattr(node, "_audits"): + setattr(node, "_audits", {}) if value is None: value = node._audits.get(key, default) @@ -74,6 +84,7 @@ def audits(node, key, value=None, default=None, delete=False): return None + def rules_no_restriction(a, recursive=False): """ Delete the rule attributes of an :obj:`ast` node (and possibly those of its @@ -81,9 +92,10 @@ def rules_no_restriction(a, recursive=False): """ if recursive: for a_ in ast.walk(a): - audits(a_, 'rules', delete=True) + audits(a_, "rules", delete=True) else: - audits(a, 'rules', delete=True) + audits(a, "rules", delete=True) + def unify(t_a, t_b): """ @@ -92,9 +104,9 @@ def unify(t_a, t_b): if t_a == t_b: return t_a - if t_a.__name__ == 'list' and t_b.__name__ == 'list': - if hasattr(t_a, '__args__') and len(t_a.__args__) == 1: - if hasattr(t_b, '__args__'): + if t_a.__name__ == "list" and t_b.__name__ == "list": + if hasattr(t_a, "__args__") and len(t_a.__args__) == 1: + if hasattr(t_b, "__args__"): if len(t_b.__args__) == 1: return unify(t_a.__args__[0], t_b.__args__[0]) return None diff --git a/nada_dsl/audit/report.py b/nada_dsl/audit/report.py index 6c65a7f..5ca61d1 100644 --- a/nada_dsl/audit/report.py +++ b/nada_dsl/audit/report.py @@ -2,6 +2,7 @@ Common functions that Nada DSL static analysis submodules can use to build interactive HTML reports. """ + # pylint: disable=wildcard-import,unused-wildcard-import,invalid-name from __future__ import annotations from typing import List, Tuple @@ -13,121 +14,135 @@ try: from nada_dsl.audit.abstract import * from nada_dsl.audit.common import * -except: # pylint: disable=bare-except # For Nada DSL Sandbox support. +except: # pylint: disable=bare-except # For Nada DSL Sandbox support. from abstract import * from common import * + def parse(source: str) -> Tuple[asttokens.ASTTokens, List[int]]: """ Parse a Python source string that represents a Nada DSL program and return both its abstract syntax tree and a list of which lines were skipped by the partial parser due to syntax errors. """ - lines = source.split('\n') + lines = source.split("\n") (_, slices) = parsial.parsial(ast.parse)(source) lines_ = [l[s] for (l, s) in zip(lines, slices)] skips = [i for i in range(len(lines)) if len(lines[i]) != len(lines_[i])] - return (asttokens.ASTTokens('\n'.join(lines_), parse=True), skips) + return (asttokens.ASTTokens("\n".join(lines_), parse=True), skips) + def locations(report_, asttokens_, a): """ Return the starting and ending locations corresponding to an :obj:`ast` node. """ - ((start_line, start_column), (end_line, end_column)) = \ + ((start_line, start_column), (end_line, end_column)) = ( asttokens_.get_text_positions(a, True) + ) # Skip any whitespace when determining the starting location. line = report_.lines[start_line - 1] - while line[start_column] == ' ' and start_column < len(line): + while line[start_column] == " " and start_column < len(line): start_column += 1 return ( richreports.location((start_line, start_column)), - richreports.location((end_line, end_column - 1)) + richreports.location((end_line, end_column - 1)), ) + def type_to_str(t): """ Convert a type and/or type error :obj:`ast` node attribute into a human-readable string. """ - if hasattr(t, '__name__'): - if t.__name__ == 'list': - if hasattr(t, '__args__'): - return 'list[' + type_to_str(t.__args__[0]) + ']' - return 'list' + if hasattr(t, "__name__"): + if t.__name__ == "list": + if hasattr(t, "__args__"): + return "list[" + type_to_str(t.__args__[0]) + "]" + return "list" return str(t.__name__) if isinstance(t, TypeError): - return str('TypeError: ' + str(t)) + return str("TypeError: " + str(t)) + + return str("TypeError: " + "type cannot be determined") - return str('TypeError: ' + 'type cannot be determined') def enrich_from_type(report_, type_, start, end): """ Enrich a range within a report according to the supplied type attribute. """ - if ( - type_ in ( - bool, int, str, range, - Party, Input, Output, - Integer, PublicInteger, SecretInteger, - Boolean, PublicBoolean, SecretBoolean - ) - or - ( - hasattr(type_, '__name__') and - type_.__name__ == 'list' - ) - ): + if type_ in ( + bool, + int, + str, + range, + Party, + Input, + Output, + Integer, + PublicInteger, + SecretInteger, + Boolean, + PublicBoolean, + SecretBoolean, + ) or (hasattr(type_, "__name__") and type_.__name__ == "list"): t_str = type_to_str(type_) report_.enrich( - start, end, - '', '', - True, - True + start, end, '', "", True, True ) if isinstance(type_, (TypeError, TypeErrorRoot)): report_.enrich( - start, end, - '', '', + start, + end, + '', + "", + True, True, - True ) + def enrich_syntaxrestriction(report_, r, start, end): """ Enrich a range within a report according to the supplied syntax restriction. """ report_.enrich( - start, end, - '', '', + start, + end, + '', + "", enrich_intermediate_lines=True, - skip_whitespace=True + skip_whitespace=True, ) report_.enrich( - start, end, + start, + end, '', - '', + "", enrich_intermediate_lines=True, - skip_whitespace=True + skip_whitespace=True, ) + def enrich_keyword(report_, start, length): """ Enrich a range within a report corresponding to a Python keyword. """ (start_line, start_column) = start report_.enrich( - (start_line, start_column), (start_line, start_column + length), - '', '', + (start_line, start_column), + (start_line, start_column + length), + '', + "", enrich_intermediate_lines=True, - skip_whitespace=True + skip_whitespace=True, ) + def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: """ Enrich a report containing the source code of a Nada DSL program using the @@ -135,11 +150,11 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: """ # pylint: disable=too-many-statements,too-many-branches for a in ast.walk(atok.tree): - r = audits(a, 'rules') - t = audits(a, 'types') + r = audits(a, "rules") + t = audits(a, "types") if isinstance(a, (ast.Assign, ast.AnnAssign)): - target = a.targets[0] if hasattr(a, 'targets') else a.target + target = a.targets[0] if hasattr(a, "targets") else a.target (start, end) = locations(report_, atok, target) if isinstance(r, SyntaxRestriction): enrich_syntaxrestriction(report_, r, start, end) @@ -147,13 +162,15 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: enrich_from_type(report_, t, start, end) t_str = ( type_to_str(t) - if not isinstance(t, TypeError) else - 'TypeError: ' + str(t) + if not isinstance(t, TypeError) + else "TypeError: " + str(t) ) report_.enrich( - start, end, - '', '', - True + start, + end, + '', + "", + True, ) elif isinstance(r, SyntaxRestriction): @@ -161,7 +178,7 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: enrich_syntaxrestriction(report_, r, start, end) elif isinstance(r, RuleInAncestor): - pass # This node will be wrapped by an ancestor's enrichment. + pass # This node will be wrapped by an ancestor's enrichment. elif isinstance(a, ast.ImportFrom): (start, _) = locations(report_, atok, a) @@ -173,11 +190,13 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: enrich_syntaxrestriction(report_, r, start, end) else: enrich_keyword(report_, start, 6) - t = audits(a.value, 'types') + t = audits(a.value, "types") report_.enrich( - start, start + (0, 6), - '', '', - True + start, + start + (0, 6), + '', + "", + True, ) elif isinstance(a, ast.For): @@ -186,10 +205,12 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: (_, start) = locations(report_, atok, a.target) (end, _) = locations(report_, atok, a.iter) report_.enrich( - start + (0, 1), end - (0, 1), - '', '', + start + (0, 1), + end - (0, 1), + '', + "", enrich_intermediate_lines=True, - skip_whitespace=True + skip_whitespace=True, ) elif isinstance(a, ast.FunctionDef): @@ -205,10 +226,12 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: (_, start) = locations(report_, atok, generator.target) (end, _) = locations(report_, atok, generator.iter) report_.enrich( - start + (0, 1), end - (0, 1), - '', '', + start + (0, 1), + end - (0, 1), + '', + "", enrich_intermediate_lines=True, - skip_whitespace=True + skip_whitespace=True, ) elif isinstance(a, ast.Call): @@ -218,9 +241,11 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: start = start + (0, 1) enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', - True + start, + end, + '', + "", + True, ) elif isinstance(a, ast.BoolOp): @@ -229,13 +254,15 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: (_, end) = locations(report_, atok, left) (start, _) = locations(report_, atok, right) (start, end) = (end + (0, 1), start - (0, 1)) - report_.enrich(start, end, '', '', True, True) + report_.enrich(start, end, "", "", True, True) enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', + start, + end, + '', + "", + True, True, - True ) elif isinstance(a, ast.BinOp): @@ -244,10 +271,12 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: (start, end) = (end + (0, 1), start - (0, 1)) enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', + start, + end, + '', + "", + True, True, - True ) elif isinstance(a, ast.Compare): @@ -256,10 +285,12 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: (start, end) = (end + (0, 1), start - (0, 1)) enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', + start, + end, + '', + "", + True, True, - True ) elif isinstance(a, ast.UnaryOp): @@ -268,11 +299,13 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: end = end - (0, 2) enrich_from_type(report_, t, start, end) if isinstance(a.op, ast.Not): - report_.enrich(start, end, '', '', True) + report_.enrich(start, end, "", "", True) report_.enrich( - start, end, - '', '', - True + start, + end, + '', + "", + True, ) elif isinstance(a, ast.Constant): @@ -282,9 +315,11 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: else: enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', - True + start, + end, + '', + "", + True, ) elif isinstance(a, ast.Name): @@ -295,16 +330,21 @@ def enrich_fromaudits(report_: richreports.report, atok) -> richreports.report: if t is not None and not isinstance(t, TypeInParent): enrich_from_type(report_, t, start, end) report_.enrich( - start, end, - '', '', - True + start, + end, + '', + "", + True, ) + def html(report: richreports.report) -> str: """ Return a self-contained CSS/HTML document corresponding to a report. """ - head = ' ' + ''' + head = ( + " " + + """ - '''.strip() + '\n' + """.strip() + + "\n" + ) - script = ' ' + ''' + script = ( + " " + + """ - '''.strip() + '\n' + """.strip() + + "\n" + ) return ( - '\n' + - ' \n' + - head + - ' \n' + - ' \n
\n' + - report.render() + - '\n \n' + - script + - '\n' + "\n" + + " \n" + + head + + " \n" + + ' \n
\n' + + report.render() + + "\n \n" + + script + + "\n" ) diff --git a/nada_dsl/audit/strict.py b/nada_dsl/audit/strict.py index 2ed0b38..7e40e38 100644 --- a/nada_dsl/audit/strict.py +++ b/nada_dsl/audit/strict.py @@ -2,6 +2,7 @@ Static analysis submodule that defines the highly limited "strict" subset of the Nada DSL syntax. """ + # pylint: disable=wildcard-import,unused-wildcard-import,invalid-name from __future__ import annotations from typing import Callable @@ -12,11 +13,12 @@ from nada_dsl.audit.abstract import * from nada_dsl.audit.common import * from nada_dsl.audit.report import parse, enrich_fromaudits -except: # pylint: disable=bare-except # For Nada DSL Sandbox support. +except: # pylint: disable=bare-except # For Nada DSL Sandbox support. from abstract import * from common import * from report import parse, enrich_fromaudits + def _rules_restrictions_descendants(a): """ Remove all rule attributes that are redundant because they occur in a @@ -24,11 +26,12 @@ def _rules_restrictions_descendants(a): """ for a_ in ast.walk(a): if not isinstance(a_, (ast.Assign, ast.Assign)): - if isinstance(audits(a_, 'rules'), (SyntaxRestriction, RuleInAncestor)): + if isinstance(audits(a_, "rules"), (SyntaxRestriction, RuleInAncestor)): for a__ in ast.walk(a_): if a__ != a_: - audits(a__, 'rules', RuleInAncestor()) - audits(a__, 'types', delete=True) + audits(a__, "rules", RuleInAncestor()) + audits(a__, "types", delete=True) + def rules(a): """ @@ -39,28 +42,36 @@ def rules(a): for a_ in ast.walk(a): audits( a_, - 'rules', - SyntaxRestriction('use of this syntax is prohibited in strict mode') + "rules", + SyntaxRestriction("use of this syntax is prohibited in strict mode"), ) + def _types_base(t): """ Return a boolean value indicating whether the supplied type is a base type. """ return t in ( - bool, int, str, - Integer, PublicInteger, SecretInteger, - Boolean, PublicBoolean, SecretBoolean + bool, + int, + str, + Integer, + PublicInteger, + SecretInteger, + Boolean, + PublicBoolean, + SecretBoolean, ) + def _types_list_monomorphic(t): """ Return a boolean value indicating whether the supplied type represents a monomorphic list type (*i.e.*, a list type wherein the types of the items are fully specified). """ - if t.__name__ == 'list' and hasattr(t, '__args__') and len(t.__args__) == 1: + if t.__name__ == "list" and hasattr(t, "__args__") and len(t.__args__) == 1: return _types_list_monomorphic(t.__args__[0]) if _types_base(t): @@ -68,15 +79,17 @@ def _types_list_monomorphic(t): return False + def _types_list_monomorphic_depth(t): """ Return an integer representing the depth of the monomorphic list type. """ - if t.__name__ == 'list' and hasattr(t, '__args__') and len(t.__args__) == 1: + if t.__name__ == "list" and hasattr(t, "__args__") and len(t.__args__) == 1: return 1 + _types_list_monomorphic_depth(t.__args__[0]) return 0 + def _types_monomorphic(t): """ Return a boolean value indicating whether the supplied type represents a @@ -90,7 +103,7 @@ def _types_binop_mult_add_sub(t_l, t_r): Determine the result type for multiplication and addition operations involving Nada DSL integers. """ - t = TypeErrorRoot('arguments must have integer types') + t = TypeErrorRoot("arguments must have integer types") if (t_l, t_r) == (Integer, Integer): t = Integer elif (t_l, t_r) == (Integer, PublicInteger): @@ -117,12 +130,13 @@ def _types_binop_mult_add_sub(t_l, t_r): return t + def _types_compare(t_l, t_r): """ Determine the result type for comparison operations involving Nada DSL integers. """ - t = TypeErrorRoot('arguments must have integer types') + t = TypeErrorRoot("arguments must have integer types") if (t_l, t_r) == (Integer, Integer): t = Boolean elif (t_l, t_r) == (Integer, PublicInteger): @@ -149,6 +163,7 @@ def _types_compare(t_l, t_r): return t + def types(a, env=None, func=False): """ Infer types of :obj:`ast` where possible, adding the type (or error) @@ -174,16 +189,18 @@ def types(a, env=None, func=False): if isinstance(a, ast.ImportFrom): if ( - a.module == 'nada_dsl' and - len(a.names) == 1 and a.names[0].name == '*' and a.names[0].asname is None and - a.level == 0 + a.module == "nada_dsl" + and len(a.names) == 1 + and a.names[0].name == "*" + and a.names[0].asname is None + and a.level == 0 ): rules_no_restriction(a, recursive=True) return env if isinstance(a, ast.FunctionDef): if not func: - if a.name == 'nada_main': + if a.name == "nada_main": rules_no_restriction(a) rules_no_restriction(a.args) @@ -197,7 +214,7 @@ def types(a, env=None, func=False): rules_no_restriction(a) rules_no_restriction(a.args) - t_ret = eval(ast.unparse(a.returns)) # pylint: disable=eval-used + t_ret = eval(ast.unparse(a.returns)) # pylint: disable=eval-used if _types_monomorphic(t_ret): rules_no_restriction(a.returns) @@ -205,7 +222,7 @@ def types(a, env=None, func=False): ts = [] for arg in a.args.args: var = arg.arg - t_var = eval(ast.unparse(arg.annotation)) # pylint: disable=eval-used + t_var = eval(ast.unparse(arg.annotation)) # pylint: disable=eval-used if _types_monomorphic(t_var): rules_no_restriction(arg) rules_no_restriction(arg.annotation, recursive=True) @@ -224,16 +241,16 @@ def types(a, env=None, func=False): rules_no_restriction(a) rules_no_restriction(target) types(a.value, env, func) - t = audits(a.value, 'types') + t = audits(a.value, "types") if t == list and not _types_list_monomorphic(t): t = TypeErrorRoot( - 'assignment of list value with underspecified ' + - 'type requires fully specified type annotation' - ) - audits(a, 'types', t) + "assignment of list value with underspecified " + + "type requires fully specified type annotation" + ) + audits(a, "types", t) elif t is not None: - audits(a, 'types', typeerror_demote(t)) - audits(target, 'types', TypeInParent()) + audits(a, "types", typeerror_demote(t)) + audits(target, "types", TypeInParent()) if not isinstance(t, TypeError): if isinstance(target, ast.Name): var = target.id @@ -241,7 +258,7 @@ def types(a, env=None, func=False): elif isinstance(target, ast.Subscript): rules_no_restriction(a) types(a.value, env, func) - t = audits(a.value, 'types') + t = audits(a.value, "types") target_ = target invalid_index = False @@ -250,7 +267,7 @@ def types(a, env=None, func=False): depth += 1 rules_no_restriction(target_) types(target_.slice, env, func) - t_s = audits(target_.slice, 'types') + t_s = audits(target_.slice, "types") if t_s != int: invalid_index = True break @@ -258,7 +275,7 @@ def types(a, env=None, func=False): target_ = target_.value if invalid_index: - audits(a, 'types', TypeErrorRoot('indices must be integers')) + audits(a, "types", TypeErrorRoot("indices must be integers")) if isinstance(target_, ast.Name): rules_no_restriction(target) @@ -266,11 +283,17 @@ def types(a, env=None, func=False): if target_.id in env: t_b = env[target_.id] if _types_list_monomorphic_depth(t_b) < depth: - audits(a, 'types', TypeErrorRoot('target has incompatible type')) + audits( + a, + "types", + TypeErrorRoot("target has incompatible type"), + ) else: - audits(a, 'types', typeerror_demote(t)) + audits(a, "types", typeerror_demote(t)) else: - audits(a, 'types', TypeErrorRoot('unbound variable: ' + target_.id)) + audits( + a, "types", TypeErrorRoot("unbound variable: " + target_.id) + ) return env @@ -279,34 +302,36 @@ def types(a, env=None, func=False): rules_no_restriction(a) rules_no_restriction(a.target) types(a.value, env, func) - t = audits(a.value, 'types') + t = audits(a.value, "types") try: - t_a = eval(ast.unparse(a.annotation)) # pylint: disable=eval-used + t_a = eval(ast.unparse(a.annotation)) # pylint: disable=eval-used rules_no_restriction(a.annotation, recursive=True) if not _types_list_monomorphic(t_a): audits( a.annotation, - 'types', + "types", TypeErrorRoot( - 'assignment of list value requires fully specified type annotation' - ) + "assignment of list value requires fully specified type annotation" + ), ) - except: # pylint: disable=bare-except - t_a = TypeErrorRoot('invalid type annotation') + except: # pylint: disable=bare-except + t_a = TypeErrorRoot("invalid type annotation") if isinstance(t, TypeError): - audits(a, 'types', t) - audits(a.target, 'types', t) + audits(a, "types", t) + audits(a.target, "types", t) else: t_u = unify(t_a, t) if t_u is None: - t = TypeErrorRoot('value type cannot be reconciled with type annotation') - audits(a, 'types', t) - audits(a.target, 'types', t) + t = TypeErrorRoot( + "value type cannot be reconciled with type annotation" + ) + audits(a, "types", t) + audits(a.target, "types", t) else: t = t_u - audits(a, 'types', t) - audits(a.target, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.target, "types", TypeInParent()) var = a.target.id env[var] = t @@ -321,29 +346,29 @@ def types(a, env=None, func=False): if isinstance(a, ast.For): rules_no_restriction(a) types(a.iter, env, func) - t_i = audits(a.iter, 'types') + t_i = audits(a.iter, "types") if isinstance(a.target, ast.Name): rules_no_restriction(a.target) var = a.target.id - audits(a.target, 'types', int) + audits(a.target, "types", int) if isinstance(t_i, TypeError): - pass # Allow the error to pass through. + pass # Allow the error to pass through. elif t_i == range: env[var] = int for a_ in a.body: env = types(a_, env, func) else: - audits(a.iter, 'types', TypeErrorRoot('iterable must be a range')) + audits(a.iter, "types", TypeErrorRoot("iterable must be a range")) return env if isinstance(a, ast.Expr): types(a.value, env, func) - if not isinstance(audits(a.value, 'rules'), SyntaxRestriction): + if not isinstance(audits(a.value, "rules"), SyntaxRestriction): rules_no_restriction(a) return env # Handle cases in which the input node is an expression. - audits(a, 'types', TypeError('type cannot be determined')) + audits(a, "types", TypeError("type cannot be determined")) if isinstance(a, ast.ListComp): rules_no_restriction(a) @@ -351,220 +376,216 @@ def types(a, env=None, func=False): for comprehension in a.generators: rules_no_restriction(comprehension) types(comprehension.iter, env, func) - t_c = audits(comprehension.iter, 'types') + t_c = audits(comprehension.iter, "types") if isinstance(comprehension.target, ast.Name): rules_no_restriction(comprehension.target) var = comprehension.target.id t = int if isinstance(t_c, TypeError): - pass # Allow the error to pass through. + pass # Allow the error to pass through. elif t_c == range: rules_no_restriction(comprehension.iter) ts[var] = t else: - t = TypeErrorRoot('iterable must be a range value') - audits(comprehension.iter, 'types', t) + t = TypeErrorRoot("iterable must be a range value") + audits(comprehension.iter, "types", t) t = typeerror_demote(t) - audits(comprehension.target, 'types', t) + audits(comprehension.target, "types", t) env_ = dict(env) - for (var, t_) in ts.items(): + for var, t_ in ts.items(): env_[var] = t_ types(a.elt, env_, func) - t_e = audits(a.elt, 'types') + t_e = audits(a.elt, "types") if t_e is not None and not isinstance(t_e, TypeError): - audits(a, 'types', list[t_e]) + audits(a, "types", list[t_e]) elif isinstance(a, ast.Call): ats = [] kts = [] for a_ in a.args: types(a_, env, func) - ats.append(audits(a_, 'types')) + ats.append(audits(a_, "types")) for a_ in a.keywords: types(a_.value, env, func) - kts.append(audits(a_.value, 'types')) + kts.append(audits(a_.value, "types")) if isinstance(a.func, ast.Attribute): types(a.func.value, env, func) - t_v = audits(a.func.value, 'types') - if a.func.attr == 'if_else': + t_v = audits(a.func.value, "types") + if a.func.attr == "if_else": if len(a.args) == 2: ts = ats rules_no_restriction(a) rules_no_restriction(a.func) if t_v == Boolean: - if ( - ts[0] in (Integer, PublicInteger, SecretInteger) and - ts[1] in (Integer, PublicInteger, SecretInteger) - ): + if ts[0] in (Integer, PublicInteger, SecretInteger) and ts[ + 1 + ] in (Integer, PublicInteger, SecretInteger): t = max(ts) else: - t = TypeErrorRoot('branches must have the same integer type') + t = TypeErrorRoot( + "branches must have the same integer type" + ) elif t_v == PublicBoolean: - if ( - ts[0] in (Integer, PublicInteger, SecretInteger) and - ts[1] in (Integer, PublicInteger, SecretInteger) - ): + if ts[0] in (Integer, PublicInteger, SecretInteger) and ts[ + 1 + ] in (Integer, PublicInteger, SecretInteger): t = max(ts) else: - t = TypeErrorRoot('branches must have the same integer type') + t = TypeErrorRoot( + "branches must have the same integer type" + ) elif t_v == SecretBoolean: - if ( - ts[0] in (Integer, PublicInteger, SecretInteger) and - ts[1] in (Integer, PublicInteger, SecretInteger) - ): + if ts[0] in (Integer, PublicInteger, SecretInteger) and ts[ + 1 + ] in (Integer, PublicInteger, SecretInteger): t = max(ts) else: - t = TypeErrorRoot('branches must have the same integer type') + t = TypeErrorRoot( + "branches must have the same integer type" + ) else: - t = TypeErrorRoot('condition must have a boolean type') + t = TypeErrorRoot("condition must have a boolean type") - audits(a, 'types', t) - elif a.func.attr == 'append': + audits(a, "types", t) + elif a.func.attr == "append": if len(a.args) == 1: rules_no_restriction(a) rules_no_restriction(a.func) t_i = ats[0] if unify(t_v, list[t_i]): - audits(a, 'types', type(None)) + audits(a, "types", type(None)) else: audits( a, - 'types', - TypeErrorRoot('item type does not match list type') + "types", + TypeErrorRoot("item type does not match list type"), ) elif isinstance(a.func, ast.Name): - if a.func.id == 'Party': + if a.func.id == "Party": rules_no_restriction(a) rules_no_restriction(a.func) - t = TypeError('party requires name parameter (a string)') + t = TypeError("party requires name parameter (a string)") if ( - len(a.args) == 0 and - len(a.keywords) == 1 and - a.keywords[0].arg == 'name' + len(a.args) == 0 + and len(a.keywords) == 1 + and a.keywords[0].arg == "name" ): rules_no_restriction(a.keywords[0]) - if audits(a.keywords[0].value, 'types') == str: + if audits(a.keywords[0].value, "types") == str: t = Party - elif ( - len(a.args) == 1 and - len(a.keywords) == 0 - ): + elif len(a.args) == 1 and len(a.keywords) == 0: rules_no_restriction(a.args[0]) - if audits(a.args[0], 'types') == str: + if audits(a.args[0], "types") == str: t = Party - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'Input': + elif a.func.id == "Input": rules_no_restriction(a) rules_no_restriction(a.func) - t = TypeError('input requires name parameter (a string) and party parameter') - if ( - len(a.args) == 2 and - len(a.keywords) == 0 - ): + t = TypeError( + "input requires name parameter (a string) and party parameter" + ) + if len(a.args) == 2 and len(a.keywords) == 0: rules_no_restriction(a.args[0]) rules_no_restriction(a.args[1]) if ( - audits(a.args[0], 'types') == str and - audits(a.args[1], 'types') == Party + audits(a.args[0], "types") == str + and audits(a.args[1], "types") == Party ): t = Input if ( - len(a.args) == 1 and - len(a.keywords) == 1 and - a.keywords[0].arg == 'party' + len(a.args) == 1 + and len(a.keywords) == 1 + and a.keywords[0].arg == "party" ): rules_no_restriction(a.args[0]) rules_no_restriction(a.keywords[0]) if ( - audits(a.args[0], 'types') == str and - audits(a.keywords[0].value, 'types') == Party + audits(a.args[0], "types") == str + and audits(a.keywords[0].value, "types") == Party ): t = Input if ( - len(a.args) == 0 and - len(a.keywords) == 2 and - a.keywords[0].arg == 'name' and - a.keywords[1].arg == 'party' + len(a.args) == 0 + and len(a.keywords) == 2 + and a.keywords[0].arg == "name" + and a.keywords[1].arg == "party" ): rules_no_restriction(a.keywords[0]) rules_no_restriction(a.keywords[1]) if ( - audits(a.keywords[0].value, 'types') == str and - audits(a.keywords[1].value, 'types') == Party + audits(a.keywords[0].value, "types") == str + and audits(a.keywords[1].value, "types") == Party ): t = Input if ( - len(a.args) == 0 and - len(a.keywords) == 2 and - a.keywords[1].arg == 'name' and - a.keywords[0].arg == 'party' + len(a.args) == 0 + and len(a.keywords) == 2 + and a.keywords[1].arg == "name" + and a.keywords[0].arg == "party" ): rules_no_restriction(a.keywords[0]) rules_no_restriction(a.keywords[1]) if ( - audits(a.keywords[1].value, 'types') == str and - audits(a.keywords[0].value, 'types') == Party + audits(a.keywords[1].value, "types") == str + and audits(a.keywords[0].value, "types") == Party ): t = Input - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'Output': + elif a.func.id == "Output": rules_no_restriction(a) rules_no_restriction(a.func) t = TypeError( - 'output requires value parameter, name parameter (a string), ' + - 'and party parameter' + "output requires value parameter, name parameter (a string), " + + "and party parameter" ) kwargs = {kw.arg: kw.value for kw in a.keywords} if len(a.args) == 0 and len(a.keywords) == 3: - if set(kwargs.keys()) == {'value', 'name', 'party'}: + if set(kwargs.keys()) == {"value", "name", "party"}: for kw in a.keywords: rules_no_restriction(kw) if ( ( - audits(kwargs['value'], 'types') - in - (SecretInteger, PublicInteger) - ) and - audits(kwargs['name'], 'types') == str and - audits(kwargs['party'], 'types') == Party + audits(kwargs["value"], "types") + in (SecretInteger, PublicInteger) + ) + and audits(kwargs["name"], "types") == str + and audits(kwargs["party"], "types") == Party ): t = Output elif len(a.args) == 1 and len(a.keywords) == 2: - if set(kwargs.keys()) == {'name', 'party'}: + if set(kwargs.keys()) == {"name", "party"}: for arg in a.args: rules_no_restriction(arg) for kw in a.keywords: rules_no_restriction(kw) if ( ( - audits(a.args[0], 'types') - in - (SecretInteger, PublicInteger) - ) and - audits(kwargs['name'], 'types') == str and - audits(kwargs['party'], 'types') == Party + audits(a.args[0], "types") + in (SecretInteger, PublicInteger) + ) + and audits(kwargs["name"], "types") == str + and audits(kwargs["party"], "types") == Party ): t = Output elif len(a.args) == 2 and len(a.keywords) == 1: - if set(kwargs.keys()) == {'party'}: + if set(kwargs.keys()) == {"party"}: for arg in a.args: rules_no_restriction(arg) for kw in a.keywords: rules_no_restriction(kw) if ( ( - audits(a.args[0], 'types') - in - (SecretInteger, PublicInteger) - ) and - audits(a.args[1], 'types') == str and - audits(kwargs['party'], 'types') == Party + audits(a.args[0], "types") + in (SecretInteger, PublicInteger) + ) + and audits(a.args[1], "types") == str + and audits(kwargs["party"], "types") == Party ): t = Output elif len(a.args) == 3 and len(a.keywords) == 0: @@ -573,124 +594,105 @@ def types(a, env=None, func=False): for kw in a.keywords: rules_no_restriction(kw) if ( - ( - audits(a.args[0], 'types') - in - (SecretInteger, PublicInteger) - ) and - audits(a.args[1], 'types') == str and - audits(a.args[2], 'types') == Party + (audits(a.args[0], "types") in (SecretInteger, PublicInteger)) + and audits(a.args[1], "types") == str + and audits(a.args[2], "types") == Party ): t = Output - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'Integer': + elif a.func.id == "Integer": rules_no_restriction(a) rules_no_restriction(a.func) t = Integer - if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != int - ): - t = TypeError('expecting single argument (an integer)') - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + if len(a.args) != 1 or audits(a.args[0], "types") != int: + t = TypeError("expecting single argument (an integer)") + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'PublicInteger': + elif a.func.id == "PublicInteger": rules_no_restriction(a) rules_no_restriction(a.func) t = PublicInteger - if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != Input - ): - t = TypeError('expecting single argument (an input object)') - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + if len(a.args) != 1 or audits(a.args[0], "types") != Input: + t = TypeError("expecting single argument (an input object)") + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'SecretInteger': + elif a.func.id == "SecretInteger": rules_no_restriction(a) rules_no_restriction(a.func) t = SecretInteger - if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != Input - ): - t = TypeError('expecting single argument (an input object)') - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + if len(a.args) != 1 or audits(a.args[0], "types") != Input: + t = TypeError("expecting single argument (an input object)") + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'range': + elif a.func.id == "range": rules_no_restriction(a) rules_no_restriction(a.func) t = range - if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != int - ): - t = TypeErrorRoot('expecting single integer argument') - if isinstance(audits(a.args[0], 'types'), TypeError): + if len(a.args) != 1 or audits(a.args[0], "types") != int: + t = TypeErrorRoot("expecting single integer argument") + if isinstance(audits(a.args[0], "types"), TypeError): t = typeerror_demote(t) - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'str': + elif a.func.id == "str": rules_no_restriction(a) rules_no_restriction(a.func) t = str - if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != int - ): - t = TypeError('expecting single integer argument') - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + if len(a.args) != 1 or audits(a.args[0], "types") != int: + t = TypeError("expecting single integer argument") + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) - elif a.func.id == 'sum': + elif a.func.id == "sum": rules_no_restriction(a) rules_no_restriction(a.func) t = SecretInteger if ( - len(a.args) != 1 or - audits(a.args[0], 'types') != list[SecretInteger] + len(a.args) != 1 + or audits(a.args[0], "types") != list[SecretInteger] ): - t = TypeError('expecting argument of type list[SecretInteger]') - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + t = TypeError("expecting argument of type list[SecretInteger]") + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) elif a.func.id in env: rules_no_restriction(a) rules_no_restriction(a.func) t_f = env[a.func.id] ts = t_f.__args__[:-1] - t = TypeErrorRoot('function arguments do not match function type') + t = TypeErrorRoot("function arguments do not match function type") if len(ats) == len(ts): if all(t_a == t for (t_a, t) in zip(ats, ts)): t = t_f.__args__[-1] - audits(a, 'types', t) - audits(a.func, 'types', TypeInParent()) + audits(a, "types", t) + audits(a.func, "types", TypeInParent()) elif isinstance(a, ast.Subscript): rules_no_restriction(a) types(a.value, env, func) types(a.slice, env, func) - t_v = audits(a.value, 'types') - t_s = audits(a.slice, 'types') + t_v = audits(a.value, "types") + t_s = audits(a.slice, "types") if (not isinstance(t_v, TypeError)) and (not isinstance(t_s, TypeError)): - t = TypeErrorRoot('expecting list value and integer index') - if t_v.__name__ == 'list' and t_s == int: - if hasattr(t_v, '__args__') and len(t_v.__args__) == 1: + t = TypeErrorRoot("expecting list value and integer index") + if t_v.__name__ == "list" and t_s == int: + if hasattr(t_v, "__args__") and len(t_v.__args__) == 1: t = t_v.__args__[0] - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a, ast.List): rules_no_restriction(a) for a_ in a.elts: types(a_, env, func) - ts = [audits(a_, 'types') for a_ in a.elts] - t = TypeError('lists must contain elements that are all of the same type') + ts = [audits(a_, "types") for a_ in a.elts] + t = TypeError("lists must contain elements that are all of the same type") if len(set(ts)) == 0: t = list elif len(set(ts)) == 1: @@ -699,66 +701,66 @@ def types(a, env=None, func=False): t = typeerror_demote(t) else: t = list[t] - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a, ast.BoolOp): rules_no_restriction(a) ts = [] for a_ in a.values: types(a_, env, func) - ts.append(audits(a_, 'types')) - t = TypeErrorRoot('arguments must be boolean values') + ts.append(audits(a_, "types")) + t = TypeErrorRoot("arguments must be boolean values") if all(t == bool for t in ts): t = bool elif any(isinstance(t, TypeError) for t in ts): t = typeerror_demote(t) - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a, ast.BinOp): types(a.left, env, func) types(a.right, env, func) - t_l = audits(a.left, 'types') - t_r = audits(a.right, 'types') + t_l = audits(a.left, "types") + t_r = audits(a.right, "types") if isinstance(a.op, ast.Add): rules_no_restriction(a) - t = TypeError('unsupported operand types') + t = TypeError("unsupported operand types") if t_l == int and t_r == int: t = int elif t_l == str and t_r == str: t = str else: t = _types_binop_mult_add_sub(t_l, t_r) - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a.op, ast.Sub): rules_no_restriction(a) - t = TypeError('unsupported operand types') + t = TypeError("unsupported operand types") if t_l == int and t_r == int: t = int else: t = _types_binop_mult_add_sub(t_l, t_r) - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a.op, ast.Mult): rules_no_restriction(a) if t_l == int and t_r == int: t = int else: t = _types_binop_mult_add_sub(t_l, t_r) - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a, ast.Compare): if len(a.comparators) != 1: audits( a, - 'rules', - SyntaxRestriction('chained comparisons are prohibited in strict mode') + "rules", + SyntaxRestriction("chained comparisons are prohibited in strict mode"), ) else: op = a.ops[0] rules_no_restriction(a) types(a.left, env, func) types(a.comparators[0], env, func) - t_l = audits(a.left, 'types') - t_r = audits(a.comparators[0], 'types') + t_l = audits(a.left, "types") + t_r = audits(a.comparators[0], "types") if isinstance(op, (ast.Eq, ast.NotEq)): if t_l == bool and t_r == bool: t = bool @@ -773,11 +775,11 @@ def types(a, env=None, func=False): t = bool else: t = _types_compare(t_l, t_r) - audits(a, 'types', t) + audits(a, "types", t) elif isinstance(a, ast.UnaryOp): types(a.operand, env, func) - t = audits(a.operand, 'types') + t = audits(a.operand, "types") if isinstance(a.op, (ast.UAdd, ast.USub)): rules_no_restriction(a) if isinstance(t, TypeError): @@ -785,8 +787,8 @@ def types(a, env=None, func=False): elif t in (int, Integer, PublicInteger, SecretInteger): pass else: - t = TypeErrorRoot('argument must have an integer type') - audits(a, 'types', t) + t = TypeErrorRoot("argument must have an integer type") + audits(a, "types", t) elif isinstance(a.op, ast.Not): rules_no_restriction(a) @@ -795,30 +797,31 @@ def types(a, env=None, func=False): elif t == bool: pass else: - t = TypeErrorRoot('argument must be a boolean value') - audits(a, 'types', t) + t = TypeErrorRoot("argument must be a boolean value") + audits(a, "types", t) elif isinstance(a, ast.Name): rules_no_restriction(a) var = a.id audits( a, - 'types', - env[var] if var in env else TypeError("name '" + var + "' is not defined") + "types", + env[var] if var in env else TypeError("name '" + var + "' is not defined"), ) elif isinstance(a, ast.Constant): - if a.value in (False, True) and str(a.value) in ('False', 'True'): + if a.value in (False, True) and str(a.value) in ("False", "True"): rules_no_restriction(a) - audits(a, 'types', bool) + audits(a, "types", bool) elif isinstance(a.value, int): rules_no_restriction(a) - audits(a, 'types', int) + audits(a, "types", int) elif isinstance(a.value, str): rules_no_restriction(a) - audits(a, 'types', str) + audits(a, "types", str) + + return env # Always return the environment. - return env # Always return the environment. def strict(source: str) -> richreports.report: """ @@ -835,26 +838,30 @@ def strict(source: str) -> richreports.report: types(root) # Perform the abstract execution. - #root.body.append(ast.Expr(ast.Call(ast.Name('nada_main', ast.Load()), [], []))) - #ast.fix_missing_locations(root) - #exec(compile(root, path, 'exec')) + # root.body.append(ast.Expr(ast.Call(ast.Name('nada_main', ast.Load()), [], []))) + # ast.fix_missing_locations(root) + # exec(compile(root, path, 'exec')) # Add the results of the analyses to the report and ensure each line is # wrapped as an HTML element. report = richreports.report(source, line=1, column=0) enrich_fromaudits(report, atok) - for (i, line) in enumerate(report.lines): + for i, line in enumerate(report.lines): if i in skips: report.enrich( - (i + 1, 0), (i + 1, len(line) - 1), - '', '', - skip_whitespace=True + (i + 1, 0), + (i + 1, len(line) - 1), + '', + "", + skip_whitespace=True, ) report.enrich( - (i + 1, 0), (i + 1, len(line) - 1), - '', '', - skip_whitespace=True + (i + 1, 0), + (i + 1, len(line) - 1), + '', + "", + skip_whitespace=True, ) - report.enrich((i + 1, 0), (i + 1, len(line)), '
', '
') + report.enrich((i + 1, 0), (i + 1, len(line)), "
", "
") return report diff --git a/nada_dsl/compiler_frontend.py b/nada_dsl/compiler_frontend.py index 6efcffc..de41075 100644 --- a/nada_dsl/compiler_frontend.py +++ b/nada_dsl/compiler_frontend.py @@ -308,7 +308,6 @@ def process_operation( add_input_to_map(operation) processed_operation = ProcessOperationOutput(operation.to_mir(), None) elif isinstance(operation, LiteralASTOperation): - LITERALS[operation.literal_index] = (str(operation.value), operation.ty) processed_operation = ProcessOperationOutput(operation.to_mir(), None) elif isinstance( diff --git a/nada_dsl/program_io.py b/nada_dsl/program_io.py index 7d23d26..cd4de48 100644 --- a/nada_dsl/program_io.py +++ b/nada_dsl/program_io.py @@ -1,6 +1,6 @@ -"""Program Input Output utilities. +"""Program Input Output utilities. -Define the types used for inputs and outputs in Nada programs. +Define the types used for inputs and outputs in Nada programs. """ diff --git a/nada_dsl/timer.py b/nada_dsl/timer.py index 983cad8..e7c1c18 100644 --- a/nada_dsl/timer.py +++ b/nada_dsl/timer.py @@ -1,8 +1,8 @@ """ Timer class -This is a timer class used to measure the performance of different stages -in the compilation process. +This is a timer class used to measure the performance of different stages +in the compilation process. """ from dataclasses import dataclass diff --git a/pyproject.toml b/pyproject.toml index 804fc32..5a2cb22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev-dependencies = [ "tomli", "requests", "typing_extensions~=4.12.2", + "ruff>=0.8.0" ] [tool.uv.sources] diff --git a/tests/audit_strict_test.py b/tests/audit_strict_test.py index fd23fe0..fc00d7e 100644 --- a/tests/audit_strict_test.py +++ b/tests/audit_strict_test.py @@ -1,10 +1,12 @@ """ Nada DSL audit component tests. """ + import richreports from nada_dsl.audit.strict import strict + def test_strict_syntax(): source = """ from nada_dsl import * @@ -56,7 +58,8 @@ def nada_main(): return [Output(value=new_int, party=party1, name="my_output")] """ - assert(isinstance(strict(source), richreports.report)) + assert isinstance(strict(source), richreports.report) + def test_strict_functional(): source = """ @@ -84,7 +87,8 @@ def nada_main(): for c in range(4) ] """ - assert(isinstance(strict(source), richreports.report)) + assert isinstance(strict(source), richreports.report) + def test_strict_imperative(): source = """ @@ -121,4 +125,4 @@ def nada_main(): return outputs """ - assert(isinstance(strict(source), richreports.report)) + assert isinstance(strict(source), richreports.report) diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index 6469cbc..e4a2199 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -332,7 +332,6 @@ def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: def test_nada_function_call(): - c = create_input(SecretInteger, "c", "party", **{}) d = create_input(SecretInteger, "c", "party", **{}) @@ -349,7 +348,6 @@ def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: def test_nada_function_using_operations(): - c = create_input(SecretInteger, "c", "party", **{}) d = create_input(SecretInteger, "d", "party", **{}) diff --git a/tests/scalar_type_test.py b/tests/scalar_type_test.py index 24abd8a..94b9aa3 100644 --- a/tests/scalar_type_test.py +++ b/tests/scalar_type_test.py @@ -6,8 +6,20 @@ from nada_dsl import Input, Party from nada_dsl.nada_types import BaseType, Mode -from nada_dsl.nada_types.scalar_types import Integer, PublicInteger, SecretInteger, Boolean, PublicBoolean, \ - SecretBoolean, UnsignedInteger, PublicUnsignedInteger, SecretUnsignedInteger, ScalarType, BooleanType +from nada_dsl.nada_types.scalar_types import ( + Integer, + PublicInteger, + SecretInteger, + Boolean, + PublicBoolean, + SecretBoolean, + UnsignedInteger, + PublicUnsignedInteger, + SecretUnsignedInteger, + ScalarType, + BooleanType, +) + def combine_lists(list1, list2): """This returns all combinations for the items of two lists""" @@ -30,7 +42,7 @@ def combine_lists(list1, list2): booleans = [ Boolean(value=True), PublicBoolean(Input(name="public", party=Party("party"))), - SecretBoolean(Input(name="secret", party=Party("party"))) + SecretBoolean(Input(name="secret", party=Party("party"))), ] # All public boolean values @@ -46,7 +58,7 @@ def combine_lists(list1, list2): integers = [ Integer(value=1), PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="secret", party=Party("party"))) + SecretInteger(Input(name="secret", party=Party("party"))), ] # All public integer values @@ -61,14 +73,14 @@ def combine_lists(list1, list2): # All integer inputs (non literal elements) variable_integers = [ PublicInteger(Input(name="public", party=Party("party"))), - SecretInteger(Input(name="public", party=Party("party"))) + SecretInteger(Input(name="public", party=Party("party"))), ] # All unsigned integer values unsigned_integers = [ UnsignedInteger(value=1), PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="secret", party=Party("party"))) + SecretUnsignedInteger(Input(name="secret", party=Party("party"))), ] # All public unsigned integer values @@ -83,7 +95,7 @@ def combine_lists(list1, list2): # All unsigned integer inputs (non-literal elements) variable_unsigned_integers = [ PublicUnsignedInteger(Input(name="public", party=Party("party"))), - SecretUnsignedInteger(Input(name="public", party=Party("party"))) + SecretUnsignedInteger(Input(name="public", party=Party("party"))), ] # Binary arithmetic operations. They are provided as functions to the tests to avoid duplicate code @@ -98,9 +110,11 @@ def combine_lists(list1, list2): # Data set for the binary arithmetic operation tests. It combines all allowed operands with the operations. binary_arithmetic_operations = ( # Integers - combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions) + combine_lists(itertools.product(integers, repeat=2), binary_arithmetic_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_arithmetic_functions + ) ) @@ -114,16 +128,16 @@ def test_binary_arithmetic_operations(left: ScalarType, right: ScalarType, opera # Allowed operands for the power operation allowed_pow_operands = ( - # Integers: Only combinations of public integers - combine_lists(public_integers, public_integers) - # UnsignedIntegers: Only combinations of public unsigned integers - + combine_lists(public_unsigned_integers, public_unsigned_integers) + # Integers: Only combinations of public integers + combine_lists(public_integers, public_integers) + # UnsignedIntegers: Only combinations of public unsigned integers + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @pytest.mark.parametrize("left, right", allowed_pow_operands) def test_pow(left: ScalarType, right: ScalarType): - result = left ** right + result = left**right assert result.base_type, left.base_type assert result.base_type, right.base_type assert result.mode.value, max([left.mode.value, right.mode.value]) @@ -137,10 +151,12 @@ def test_pow(left: ScalarType, right: ScalarType): # The shift operations accept public unsigned integers on the right operand only. allowed_shift_operands = ( - # Integers on the left operand - combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) - # UnsignedIntegers on the left operand - + combine_lists(combine_lists(unsigned_integers, public_unsigned_integers), shift_functions) + # Integers on the left operand + combine_lists(combine_lists(integers, public_unsigned_integers), shift_functions) + # UnsignedIntegers on the left operand + + combine_lists( + combine_lists(unsigned_integers, public_unsigned_integers), shift_functions + ) ) @@ -157,15 +173,17 @@ def test_shift(left: ScalarType, right: ScalarType, operation): lambda lhs, rhs: lhs < rhs, lambda lhs, rhs: lhs > rhs, lambda lhs, rhs: lhs <= rhs, - lambda lhs, rhs: lhs >= rhs + lambda lhs, rhs: lhs >= rhs, ] # Allowed operands that are accepted by the numeric relational operations. They are combined with the operations. binary_relational_operations = ( - # Integers - combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) - # UnsignedIntegers - + combine_lists(itertools.product(unsigned_integers, repeat=2), binary_relational_functions) + # Integers + combine_lists(itertools.product(integers, repeat=2), binary_relational_functions) + # UnsignedIntegers + + combine_lists( + itertools.product(unsigned_integers, repeat=2), binary_relational_functions + ) ) @@ -177,16 +195,13 @@ def test_binary_relational_operations(left: ScalarType, right: ScalarType, opera # Equality operations -equals_functions = [ - lambda lhs, rhs: lhs == rhs, - lambda lhs, rhs: lhs != rhs -] +equals_functions = [lambda lhs, rhs: lhs == rhs, lambda lhs, rhs: lhs != rhs] # Allowed operands that are accepted by the equality operations. They are combined with the operations. equals_operations = ( - combine_lists(itertools.product(integers, repeat=2), equals_functions) - + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) - + combine_lists(itertools.product(booleans, repeat=2), equals_functions) + combine_lists(itertools.product(integers, repeat=2), equals_functions) + + combine_lists(itertools.product(unsigned_integers, repeat=2), equals_functions) + + combine_lists(itertools.product(booleans, repeat=2), equals_functions) ) @@ -199,17 +214,27 @@ def test_equals_operations(left: ScalarType, right: ScalarType, operation): # Allowed operands that are accepted by the public_equals function. Literals are not accepted. public_equals_operands = ( - # Integers - combine_lists(variable_integers, variable_integers) - # UnsignedIntegers - + combine_lists(variable_unsigned_integers, variable_unsigned_integers) + # Integers + combine_lists(variable_integers, variable_integers) + # UnsignedIntegers + + combine_lists(variable_unsigned_integers, variable_unsigned_integers) ) @pytest.mark.parametrize("left, right", public_equals_operands) def test_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): assert isinstance(left.public_equals(right), PublicBoolean) @@ -218,11 +243,13 @@ def test_public_equals( logic_functions = [ lambda lhs, rhs: lhs & rhs, lambda lhs, rhs: lhs | rhs, - lambda lhs, rhs: lhs ^ rhs + lambda lhs, rhs: lhs ^ rhs, ] # Allowed operands that are accepted by the logic operations. They are combined with the operations. -binary_logic_operations = combine_lists(combine_lists(booleans, booleans), logic_functions) +binary_logic_operations = combine_lists( + combine_lists(booleans, booleans), logic_functions +) @pytest.mark.parametrize("left, right, operation", binary_logic_operations) @@ -240,10 +267,9 @@ def test_invert_operations(operand): # Allowed operands that are accepted by the probabilistic truncation. -trunc_pr_operands = ( - combine_lists(secret_integers, public_unsigned_integers) - + combine_lists(secret_unsigned_integers, public_unsigned_integers) -) +trunc_pr_operands = combine_lists( + secret_integers, public_unsigned_integers +) + combine_lists(secret_unsigned_integers, public_unsigned_integers) @pytest.mark.parametrize("left, right", trunc_pr_operands) @@ -279,10 +305,14 @@ def test_to_public(operand): # Allow combination of operands that are accepted by if_else function if_else_operands = ( - combine_lists(secret_booleans, combine_lists(integers, integers)) - + combine_lists([public_boolean], combine_lists(integers, integers)) - + combine_lists(secret_booleans, combine_lists(unsigned_integers, unsigned_integers)) - + combine_lists([public_boolean], combine_lists(unsigned_integers, unsigned_integers)) + combine_lists(secret_booleans, combine_lists(integers, integers)) + + combine_lists([public_boolean], combine_lists(integers, integers)) + + combine_lists( + secret_booleans, combine_lists(unsigned_integers, unsigned_integers) + ) + + combine_lists( + [public_boolean], combine_lists(unsigned_integers, unsigned_integers) + ) ) @@ -296,40 +326,57 @@ def test_if_else(condition: BooleanType, left: ScalarType, right: ScalarType): # List of not allowed operations -not_allowed_binary_operations = \ - ( # Arithmetic operations - combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_arithmetic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_arithmetic_functions) - # Relational operations - + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) - + combine_lists(combine_lists(booleans, integers), binary_relational_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(integers, unsigned_integers), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), binary_relational_functions) - + combine_lists(combine_lists(unsigned_integers, integers), binary_relational_functions) - # Equals operations - + combine_lists(combine_lists(booleans, integers), equals_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) - + combine_lists(combine_lists(integers, booleans), equals_functions) - + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) - + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) - # Logic operations - + combine_lists(combine_lists(booleans, integers), logic_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) - + combine_lists(combine_lists(integers, booleans), logic_functions) - + combine_lists(combine_lists(integers, integers), logic_functions) - + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) - + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) - + combine_lists(combine_lists(unsigned_integers, unsigned_integers), logic_functions) +not_allowed_binary_operations = ( # Arithmetic operations + combine_lists(combine_lists(booleans, booleans), binary_arithmetic_functions) + + combine_lists(combine_lists(booleans, integers), binary_arithmetic_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_arithmetic_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_arithmetic_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_arithmetic_functions ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_arithmetic_functions + ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_arithmetic_functions + ) + # Relational operations + + combine_lists(combine_lists(booleans, booleans), binary_relational_functions) + + combine_lists(combine_lists(booleans, integers), binary_relational_functions) + + combine_lists( + combine_lists(booleans, unsigned_integers), binary_relational_functions + ) + + combine_lists(combine_lists(integers, booleans), binary_relational_functions) + + combine_lists( + combine_lists(integers, unsigned_integers), binary_relational_functions + ) + + combine_lists( + combine_lists(unsigned_integers, booleans), binary_relational_functions + ) + + combine_lists( + combine_lists(unsigned_integers, integers), binary_relational_functions + ) + # Equals operations + + combine_lists(combine_lists(booleans, integers), equals_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), equals_functions) + + combine_lists(combine_lists(integers, booleans), equals_functions) + + combine_lists(combine_lists(integers, unsigned_integers), equals_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), equals_functions) + + combine_lists(combine_lists(unsigned_integers, integers), equals_functions) + # Logic operations + + combine_lists(combine_lists(booleans, integers), logic_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), logic_functions) + + combine_lists(combine_lists(integers, booleans), logic_functions) + + combine_lists(combine_lists(integers, integers), logic_functions) + + combine_lists(combine_lists(integers, unsigned_integers), logic_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), logic_functions) + + combine_lists(combine_lists(unsigned_integers, integers), logic_functions) + + combine_lists( + combine_lists(unsigned_integers, unsigned_integers), logic_functions + ) +) @pytest.mark.parametrize("left, right, operation", not_allowed_binary_operations) @@ -341,38 +388,40 @@ def test_not_allowed_binary_operations(left, right, operation): # List of operands that the operation power does not accept. not_allowed_pow = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(secret_integers, integers) - + combine_lists(public_integers, secret_integers) - + combine_lists(integers, unsigned_integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(secret_unsigned_integers, unsigned_integers) - + combine_lists(public_unsigned_integers, secret_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(secret_integers, integers) + + combine_lists(public_integers, secret_integers) + + combine_lists(integers, unsigned_integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(secret_unsigned_integers, unsigned_integers) + + combine_lists(public_unsigned_integers, secret_unsigned_integers) ) @pytest.mark.parametrize("left, right", not_allowed_pow) def test_not_allowed_pow(left, right): with pytest.raises(Exception) as invalid_operation: - left ** right + left**right assert invalid_operation.type == TypeError # List of operands that the shift operation do not accept. not_allowed_shift = ( - combine_lists(combine_lists(booleans, booleans), shift_functions) - + combine_lists(combine_lists(integers, booleans), shift_functions) - + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) - + combine_lists(combine_lists(booleans, integers), shift_functions) - + combine_lists(combine_lists(integers, integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) - + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) - + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) - + combine_lists(combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions) + combine_lists(combine_lists(booleans, booleans), shift_functions) + + combine_lists(combine_lists(integers, booleans), shift_functions) + + combine_lists(combine_lists(unsigned_integers, booleans), shift_functions) + + combine_lists(combine_lists(booleans, integers), shift_functions) + + combine_lists(combine_lists(integers, integers), shift_functions) + + combine_lists(combine_lists(unsigned_integers, integers), shift_functions) + + combine_lists(combine_lists(booleans, unsigned_integers), shift_functions) + + combine_lists(combine_lists(integers, secret_unsigned_integers), shift_functions) + + combine_lists( + combine_lists(unsigned_integers, secret_unsigned_integers), shift_functions + ) ) @@ -384,14 +433,25 @@ def test_not_allowed_shift(left, right, operation): # List of operands that the public_equals function does not accept. -not_allowed_public_equals_operands = (combine_lists(variable_integers, variable_unsigned_integers) - + combine_lists(variable_unsigned_integers, variable_integers)) +not_allowed_public_equals_operands = combine_lists( + variable_integers, variable_unsigned_integers +) + combine_lists(variable_unsigned_integers, variable_integers) @pytest.mark.parametrize("left, right", not_allowed_public_equals_operands) def test_not_allowed_public_equals( - left: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] - , right: Union["PublicInteger", "SecretInteger", "PublicUnsignedInteger", "SecretUnsignedInteger"] + left: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], + right: Union[ + "PublicInteger", + "SecretInteger", + "PublicUnsignedInteger", + "SecretUnsignedInteger", + ], ): with pytest.raises(Exception) as invalid_operation: left.public_equals(right) @@ -411,17 +471,17 @@ def test_not_allowed_invert_operations(operand): # List of operands that the probabilistic truncation does not accept. not_allowed_trunc_pr_operands = ( - combine_lists(booleans, booleans) - + combine_lists(integers, booleans) - + combine_lists(unsigned_integers, booleans) - + combine_lists(booleans, integers) - + combine_lists(integers, integers) - + combine_lists(unsigned_integers, integers) - + combine_lists(booleans, unsigned_integers) - + combine_lists(integers, secret_unsigned_integers) - + combine_lists(public_integers, public_unsigned_integers) - + combine_lists(unsigned_integers, secret_unsigned_integers) - + combine_lists(public_unsigned_integers, public_unsigned_integers) + combine_lists(booleans, booleans) + + combine_lists(integers, booleans) + + combine_lists(unsigned_integers, booleans) + + combine_lists(booleans, integers) + + combine_lists(integers, integers) + + combine_lists(unsigned_integers, integers) + + combine_lists(booleans, unsigned_integers) + + combine_lists(integers, secret_unsigned_integers) + + combine_lists(public_integers, public_unsigned_integers) + + combine_lists(unsigned_integers, secret_unsigned_integers) + + combine_lists(public_unsigned_integers, public_unsigned_integers) ) @@ -429,7 +489,9 @@ def test_not_allowed_invert_operations(operand): def test_not_allowed_trunc_pr(left, right): with pytest.raises(Exception) as invalid_operation: left.trunc_pr(right) - assert invalid_operation.type == TypeError or invalid_operation.type == AttributeError + assert ( + invalid_operation.type == TypeError or invalid_operation.type == AttributeError + ) # List of types that cannot generate a random value @@ -442,20 +504,23 @@ def test_not_allowed_random(operand): operand.random() assert invalid_operation.type == AttributeError + # List of operands that the function if_else does not accept not_allowed_if_else_operands = ( - # Boolean branches - combine_lists(booleans, combine_lists(booleans, booleans)) - # Branches with different types - + combine_lists(booleans, combine_lists(integers, booleans)) - + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) - + combine_lists(booleans, combine_lists(booleans, integers)) - + combine_lists(booleans, combine_lists(unsigned_integers, integers)) - + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) - + combine_lists(booleans, combine_lists(integers, unsigned_integers)) - # The condition is a literal - + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) - + combine_lists([Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers)) + # Boolean branches + combine_lists(booleans, combine_lists(booleans, booleans)) + # Branches with different types + + combine_lists(booleans, combine_lists(integers, booleans)) + + combine_lists(booleans, combine_lists(unsigned_integers, booleans)) + + combine_lists(booleans, combine_lists(booleans, integers)) + + combine_lists(booleans, combine_lists(unsigned_integers, integers)) + + combine_lists(booleans, combine_lists(booleans, unsigned_integers)) + + combine_lists(booleans, combine_lists(integers, unsigned_integers)) + # The condition is a literal + + combine_lists([Boolean(value=True)], combine_lists(integers, integers)) + + combine_lists( + [Boolean(value=True)], combine_lists(unsigned_integers, unsigned_integers) + ) ) diff --git a/uv.lock b/uv.lock index a9532c1..bc07a9b 100644 --- a/uv.lock +++ b/uv.lock @@ -651,6 +651,7 @@ dev = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "requests" }, + { name = "ruff" }, { name = "sphinx" }, { name = "sphinx-autoapi" }, { name = "sphinx-rtd-theme" }, @@ -682,6 +683,7 @@ dev = [ { name = "pytest", specifier = ">=7.4,<9.0" }, { name = "pytest-cov", specifier = ">=4,<7" }, { name = "requests" }, + { name = "ruff", specifier = ">=0.8.0" }, { name = "sphinx", specifier = ">=5,<9" }, { name = "sphinx-autoapi", specifier = "~=3.3.2" }, { name = "sphinx-rtd-theme", specifier = ">=1.0,<3.1" }, @@ -908,6 +910,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bb/e5/6d9baab97743fab7c168d3ee330ebc1b3d6c90df37469a5ce4e3fa90f811/richreports-0.2.0-py3-none-any.whl", hash = "sha256:b99a4a0fb65d53f0d68e577518a89d5e098a9361a72fb1df7f8f0b9d4b6df2ac", size = 7534 }, ] +[[package]] +name = "ruff" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/d6/a2373f3ba7180ddb44420d2a9d1f1510e1a4d162b3d27282bedcb09c8da9/ruff-0.8.0.tar.gz", hash = "sha256:a7ccfe6331bf8c8dad715753e157457faf7351c2b69f62f32c165c2dbcbacd44", size = 3276537 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/77/e889ee3ce7fd8baa3ed1b77a03b9fb8ec1be68be1418261522fd6a5405e0/ruff-0.8.0-py3-none-linux_armv6l.whl", hash = "sha256:fcb1bf2cc6706adae9d79c8d86478677e3bbd4ced796ccad106fd4776d395fea", size = 10518283 }, + { url = "https://files.pythonhosted.org/packages/da/c8/0a47de01edf19fb22f5f9b7964f46a68d0bdff20144d134556ffd1ba9154/ruff-0.8.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:295bb4c02d58ff2ef4378a1870c20af30723013f441c9d1637a008baaf928c8b", size = 10317691 }, + { url = "https://files.pythonhosted.org/packages/41/17/9885e4a0eeae07abd2a4ebabc3246f556719f24efa477ba2739146c4635a/ruff-0.8.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7b1f1c76b47c18fa92ee78b60d2d20d7e866c55ee603e7d19c1e991fad933a9a", size = 9940999 }, + { url = "https://files.pythonhosted.org/packages/3e/cd/46b6f7043597eb318b5f5482c8ae8f5491cccce771e85f59d23106f2d179/ruff-0.8.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb0d4f250a7711b67ad513fde67e8870109e5ce590a801c3722580fe98c33a99", size = 10772437 }, + { url = "https://files.pythonhosted.org/packages/5d/87/afc95aeb8bc78b1d8a3461717a4419c05aa8aa943d4c9cbd441630f85584/ruff-0.8.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e55cce9aa93c5d0d4e3937e47b169035c7e91c8655b0974e61bb79cf398d49c", size = 10299156 }, + { url = "https://files.pythonhosted.org/packages/65/fa/04c647bb809c4d65e8eae1ed1c654d9481b21dd942e743cd33511687b9f9/ruff-0.8.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f4cd64916d8e732ce6b87f3f5296a8942d285bbbc161acee7fe561134af64f9", size = 11325819 }, + { url = "https://files.pythonhosted.org/packages/90/26/7dad6e7d833d391a8a1afe4ee70ca6f36c4a297d3cca83ef10e83e9aacf3/ruff-0.8.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c5c1466be2a2ebdf7c5450dd5d980cc87c8ba6976fb82582fea18823da6fa362", size = 12023927 }, + { url = "https://files.pythonhosted.org/packages/24/a0/be5296dda6428ba8a13bda8d09fbc0e14c810b485478733886e61597ae2b/ruff-0.8.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2dabfd05b96b7b8f2da00d53c514eea842bff83e41e1cceb08ae1966254a51df", size = 11589702 }, + { url = "https://files.pythonhosted.org/packages/26/3f/7602eb11d2886db545834182a9dbe500b8211fcbc9b4064bf9d358bbbbb4/ruff-0.8.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:facebdfe5a5af6b1588a1d26d170635ead6892d0e314477e80256ef4a8470cf3", size = 12782936 }, + { url = "https://files.pythonhosted.org/packages/4c/5d/083181bdec4ec92a431c1291d3fff65eef3ded630a4b55eb735000ef5f3b/ruff-0.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87a8e86bae0dbd749c815211ca11e3a7bd559b9710746c559ed63106d382bd9c", size = 11138488 }, + { url = "https://files.pythonhosted.org/packages/b7/23/c12cdef58413cee2436d6a177aa06f7a366ebbca916cf10820706f632459/ruff-0.8.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:85e654f0ded7befe2d61eeaf3d3b1e4ef3894469cd664ffa85006c7720f1e4a2", size = 10744474 }, + { url = "https://files.pythonhosted.org/packages/29/61/a12f3b81520083cd7c5caa24ba61bb99fd1060256482eff0ef04cc5ccd1b/ruff-0.8.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:83a55679c4cb449fa527b8497cadf54f076603cc36779b2170b24f704171ce70", size = 10369029 }, + { url = "https://files.pythonhosted.org/packages/08/2a/c013f4f3e4a54596c369cee74c24870ed1d534f31a35504908b1fc97017a/ruff-0.8.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:812e2052121634cf13cd6fddf0c1871d0ead1aad40a1a258753c04c18bb71bbd", size = 10867481 }, + { url = "https://files.pythonhosted.org/packages/d5/f7/685b1e1d42a3e94ceb25eab23c70bdd8c0ab66a43121ef83fe6db5a58756/ruff-0.8.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:780d5d8523c04202184405e60c98d7595bdb498c3c6abba3b6d4cdf2ca2af426", size = 11237117 }, + { url = "https://files.pythonhosted.org/packages/03/20/401132c0908e8837625e3b7e32df9962e7cd681a4df1e16a10e2a5b4ecda/ruff-0.8.0-py3-none-win32.whl", hash = "sha256:5fdb6efecc3eb60bba5819679466471fd7d13c53487df7248d6e27146e985468", size = 8783511 }, + { url = "https://files.pythonhosted.org/packages/1d/5c/4d800fca7854f62ad77f2c0d99b4b585f03e2d87a6ec1ecea85543a14a3c/ruff-0.8.0-py3-none-win_amd64.whl", hash = "sha256:582891c57b96228d146725975fbb942e1f30a0c4ba19722e692ca3eb25cc9b4f", size = 9559876 }, + { url = "https://files.pythonhosted.org/packages/5b/bc/cc8a6a5ca4960b226dc15dd8fb511dd11f2014ff89d325c0b9b9faa9871f/ruff-0.8.0-py3-none-win_arm64.whl", hash = "sha256:ba93e6294e9a737cd726b74b09a6972e36bb511f9a102f1d9a7e1ce94dd206a6", size = 8939733 }, +] + [[package]] name = "setuptools" version = "75.3.0"