From 800bd4a58f91ecf006e97688ac4a5df96ef560e4 Mon Sep 17 00:00:00 2001 From: Dmitri Prime Date: Fri, 20 Sep 2024 10:19:16 -0700 Subject: [PATCH] Fix docstring lints as reported by internal linter. (#180) --- compiler/back_end/cpp/header_generator.py | 70 +++++++++----- compiler/back_end/util/code_template_test.py | 2 +- compiler/front_end/attribute_checker.py | 1 + compiler/front_end/constraints.py | 1 + compiler/front_end/dependency_checker.py | 7 +- compiler/front_end/format_emb.py | 5 + compiler/front_end/lr1.py | 17 ++-- compiler/front_end/module_ir.py | 7 +- compiler/front_end/synthetics.py | 1 + compiler/front_end/type_check.py | 1 + compiler/util/attribute_util.py | 38 ++++---- compiler/util/ir_data.py | 10 +- compiler/util/ir_data_fields.py | 56 +++++++---- compiler/util/ir_data_fields_test.py | 36 +++---- compiler/util/ir_data_utils.py | 99 ++++++++++++++------ compiler/util/ir_data_utils_test.py | 54 +++++------ compiler/util/name_conversion.py | 13 ++- 17 files changed, 265 insertions(+), 153 deletions(-) diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py index 320b15f..e638471 100644 --- a/compiler/back_end/cpp/header_generator.py +++ b/compiler/back_end/cpp/header_generator.py @@ -188,7 +188,7 @@ class Config(NamedTuple): def _get_namespace_components(namespace): - """Gets the components of a C++ namespace + """Gets the components of a C++ namespace. Examples: "::some::name::detail" -> ["some", "name", "detail"] @@ -1339,10 +1339,12 @@ def _generate_structure_definition(type_ir, ir, config: Config): Arguments: type_ir: The IR for the struct definition. ir: The full IR; used for type lookups. + config: The code generation configuration to use. Returns: - A tuple of: (forward declaration for classes, class bodies, method bodies), - suitable for insertion into the appropriate places in the generated header. 
+ A tuple of: (forward declaration for classes, class bodies, method + bodies), suitable for insertion into the appropriate places in the + generated header. """ subtype_bodies, subtype_forward_declarations, subtype_method_definitions = ( _generate_subtype_definitions(type_ir, ir, config) @@ -1530,10 +1532,17 @@ def _generate_structure_definition(type_ir, ir, config: Config): def _split_enum_case_values_into_spans(enum_case_value): """Yields spans containing each enum case in an enum_case attribute value. - Each span is of the form (start, end), which is the start and end position - relative to the beginning of the enum_case_value string. To keep the grammar - of this attribute simple, this only splits on delimiters and trims whitespace - for each case. + Arguments: + enum_case_value: the value of the `enum_case` attribute to be parsed. + + Returns: + An iterator over spans, where each span covers one enum case name. + Each span is a half-open range of the form [start, end), which is the + start and end position relative to the beginning of the enum_case_value + string. The name can be retrieved with `enum_case_value[start:end]`. + + To keep the grammar of this attribute simple, this only splits on + delimiters and trims whitespace for each case. Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]""" # Scan the string from left to right, finding commas and trimming whitespace. @@ -1568,6 +1577,12 @@ def _split_enum_case_values_into_spans(enum_case_value): def _split_enum_case_values(enum_case_value): """Returns all enum cases in an enum case value. + Arguments: + enum_case_value: the value of the enum case attribute to parse. + + Returns: + All enum case names from `enum_case_value`. 
+ Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']""" return [ enum_case_value[start:end] @@ -1576,7 +1591,7 @@ def _split_enum_case_values(enum_case_value): def _get_enum_value_names(enum_value): - """Determines one or more enum names based on attributes""" + """Determines one or more enum names based on attributes.""" cases = ["SHOUTY_CASE"] name = enum_value.name.name.text if enum_case := ir_util.get_attribute( @@ -1698,14 +1713,16 @@ def _propagate_defaults(ir, targets, ancestors, add_fn): Traverses the IR to propagate default values to target nodes. Arguments: - targets: A list of target IR types to add attributes to. - ancestors: Ancestor types which may contain the default values. - add_fn: Function to add the attribute. May use any parameter available in - fast_traverse_ir_top_down actions as well as `defaults` containing the - default attributes set by ancestors. + ir: The IR to process. + targets: A list of target IR types to add attributes to. + ancestors: Ancestor types which may contain the default values. + add_fn: Function to add the attribute. May use any parameter available + in fast_traverse_ir_top_down actions as well as `defaults` + containing the + default attributes set by ancestors. Returns: - None + None """ traverse_ir.fast_traverse_ir_top_down( ir, @@ -1719,14 +1736,19 @@ def _propagate_defaults(ir, targets, ancestors, add_fn): def _offset_source_location_column(source_location, offset): - """Adds offsets from the start column of the supplied source location + """Adds offsets from the start column of the supplied source location. - Returns a new source location with all of the same properties as the provided - source location, but with the columns modified by offsets from the original - start column. + Arguments: + source_location: the initial source location + offset: a tuple of (start, end), which are the offsets relative to + source_location.start.column to set the new start.column and + end.column. 
- Offset should be a tuple of (start, end), which are the offsets relative to - source_location.start.column to set the new start.column and end.column.""" + Returns: + A new source location with all of the same properties as the provided + source location, but with the columns modified by offsets from the + original start column. + """ new_location = ir_data_utils.copy(source_location) new_location.start.column = source_location.start.column + offset[0] @@ -1863,8 +1885,12 @@ def _verify_attribute_values(ir): def _propagate_defaults_and_verify_attributes(ir): """Verify attributes and ensure defaults are set when not overridden. - Returns a list of errors if there are errors present, or an empty list if - verification completed successfully.""" + Arguments: + ir: The IR to process. + + Returns: + A list of errors if there are errors present, or an empty list if + verification completed successfully.""" if errors := attribute_util.check_attributes_in_ir( ir, back_end="cpp", diff --git a/compiler/back_end/util/code_template_test.py b/compiler/back_end/util/code_template_test.py index e4354fe..8149aaa 100644 --- a/compiler/back_end/util/code_template_test.py +++ b/compiler/back_end/util/code_template_test.py @@ -55,7 +55,7 @@ class ParseTemplatesTest(unittest.TestCase): """Tests for code_template.parse_templates.""" def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name - """Compares the results of a parse_templates""" + """Compares the results of a parse_templates.""" # Extract the name and template from the result tuple actual = {k: v.template for k, v in actual._asdict().items()} self.assertEqual(expected, actual) diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py index 9e7fec2..d1df30c 100644 --- a/compiler/front_end/attribute_checker.py +++ b/compiler/front_end/attribute_checker.py @@ -44,6 +44,7 @@ def _valid_back_ends(attr, module_source_file): + """Checks that `attr` holds a valid list of 
back end specifiers.""" if not re.match( r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$", attr.value.string_constant.text, diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py index 3249e6d..852c9ad 100644 --- a/compiler/front_end/constraints.py +++ b/compiler/front_end/constraints.py @@ -436,6 +436,7 @@ def _check_physical_type_requirements( def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir, errors): + """Verifies that atomic fields have types that are allowed in `bits`.""" if not type_ir.HasField("atomic_type"): return referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) diff --git a/compiler/front_end/dependency_checker.py b/compiler/front_end/dependency_checker.py index 8a9e903..fb622e1 100644 --- a/compiler/front_end/dependency_checker.py +++ b/compiler/front_end/dependency_checker.py @@ -23,14 +23,15 @@ def _add_reference_to_dependencies( reference, dependencies, name, source_file_name, errors ): + """Adds the specified `reference` to the `dependencies` set.""" if reference.canonical_name.object_path[0] in { "$is_statically_sized", "$static_size_in_bits", "$next", }: - # This error is a bit opaque, but given that the compiler used to crash on - # this case -- for a couple of years -- and no one complained, it seems - # safe to assume that this is a rare error. + # This error is a bit opaque, but given that the compiler used to crash + # on this case -- for a couple of years -- and no one complained, it + # seems safe to assume that this is a rare error. 
errors.append( [ error.error( diff --git a/compiler/front_end/format_emb.py b/compiler/front_end/format_emb.py index df9bcf3..fc7bf94 100644 --- a/compiler/front_end/format_emb.py +++ b/compiler/front_end/format_emb.py @@ -485,6 +485,7 @@ def _type(struct, name, colon, comment, eol, body): " type-definition* struct-field-block Dedent" ) def _structure_body(indent, docs, attributes, type_definitions, fields, dedent, config): + """Formats a structure (`bits` or `struct`) body.""" del indent, dedent # Unused. spacing = [_Row("field-separator")] if _should_add_blank_lines(fields) else [] columnized_fields = _columnize(fields, config.indent_width, indent_columns=2) @@ -609,6 +610,7 @@ def _field_body(indent, docs, attributes, dedent): ' field-location "bits" ":" Comment? eol anonymous-bits-body' ) def _inline_bits(location, bits, colon, comment, eol, body): + """Formats an inline `bits` definition.""" # Even though an anonymous bits field technically defines a new, anonymous # type, conceptually it's more like defining a bunch of fields on the # surrounding type, so it is treated as an inline list of blocks, instead of @@ -1017,6 +1019,7 @@ def _concatenate(*elements): @_formats("or-expression-right -> or-operator comparison-expression") @_formats("and-expression-right -> and-operator comparison-expression") def _concatenate_with_prefix_spaces(*elements): + """Concatenates non-empty `elements` with leading spaces.""" return "".join(" " + element for element in elements if element) @@ -1032,10 +1035,12 @@ def _concatenate_with_prefix_spaces(*elements): ) @_formats('parameter-definition-list-tail -> "," parameter-definition') def _concatenate_with_spaces(*elements): + """Concatenates non-empty `elements` with spaces between.""" return _concatenate_with(" ", *elements) def _concatenate_with(joiner, *elements): + """Concatenates non-empty `elements` with `joiner` between.""" return joiner.join(element for element in elements if element) diff --git 
a/compiler/front_end/lr1.py b/compiler/front_end/lr1.py index be99f95..579d729 100644 --- a/compiler/front_end/lr1.py +++ b/compiler/front_end/lr1.py @@ -36,16 +36,17 @@ class Item( ): """An Item is an LR(1) Item: a production, a cursor location, and a terminal. - An Item represents a partially-parsed production, and a lookahead symbol. The - position of the dot indicates what portion of the production has been parsed. - Generally, Items are an internal implementation detail, but they can be useful - elsewhere, particularly for debugging. + An Item represents a partially-parsed production, and a lookahead symbol. + The position of the dot indicates what portion of the production has been + parsed. Generally, Items are an internal implementation detail, but they + can be useful elsewhere, particularly for debugging. Attributes: - production: The Production this Item covers. - dot: The index of the "dot" in production's rhs. - terminal: The terminal lookahead symbol that follows the production in the - input stream. + production: The Production this Item covers. + dot: The index of the "dot" in production's rhs. + terminal: The terminal lookahead symbol that follows the production in + the input stream. + next_symbol: The lookahead symbol. """ def __str__(self): diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py index bd27c8a..4a459c2 100644 --- a/compiler/front_end/module_ir.py +++ b/compiler/front_end/module_ir.py @@ -339,6 +339,7 @@ def _attribute( attribute_value, close_bracket, ): + """Assembles an attribute IR node.""" del open_bracket, colon, close_bracket # Unused. 
if context_specifier.list: return ir_data.Attribute( @@ -460,14 +461,15 @@ def _expression(expression): ' ":" logical-expression' ) def _choice_expression(condition, question, if_true, colon, if_false): + """Constructs an IR node for a choice operator (`?:`) expression.""" location = parser_types.make_location( condition.source_location.start, if_false.source_location.end ) operator_location = parser_types.make_location( question.source_location.start, colon.source_location.end ) - # The function_name is a bit weird, but should suffice for any error messages - # that might need it. + # The function_name is a bit weird, but should suffice for any error + # messages that might need it. return ir_data.Expression( function=ir_data.Function( function=ir_data.FunctionMapping.CHOICE, @@ -1284,6 +1286,7 @@ def _enum_body(indent, docs, attributes, values, dedent): def _enum_value( name, equals, expression, attribute, documentation, comment, newline, body ): + """Constructs an IR node for an enum value statement (`NAME = value`).""" del equals, comment, newline # Unused. 
result = ir_data.EnumValue( name=name, diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py index 1b331a3..f55e32f 100644 --- a/compiler/front_end/synthetics.py +++ b/compiler/front_end/synthetics.py @@ -236,6 +236,7 @@ def _add_size_virtuals(structure, type_definition): def _maybe_replace_next_keyword_in_expression( expression_ir, last_location, source_file_name, errors ): + """Replaces the `$next` keyword in an expression.""" if not expression_ir.HasField("builtin_reference"): return if ( diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py index f562cc8..c3f7870 100644 --- a/compiler/front_end/type_check.py +++ b/compiler/front_end/type_check.py @@ -133,6 +133,7 @@ def _type_check_constant_reference(expression, source_file_name, ir, errors): def _type_check_operation(expression, source_file_name, ir, errors): + """Type checks a function or operator expression.""" for arg in expression.function.args: _type_check_expression(arg, source_file_name, ir, errors) function = expression.function.function diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py index a83cf51..2ebbffa 100644 --- a/compiler/util/attribute_util.py +++ b/compiler/util/attribute_util.py @@ -323,22 +323,24 @@ def _check_attributes( ): """Performs basic checks on the given list of attributes. - Checks the given attribute_list for duplicates, unknown attributes, attributes - with incorrect type, and attributes whose values are not constant. + Checks the given attribute_list for duplicates, unknown attributes, + attributes with incorrect type, and attributes whose values are not + constant. Arguments: - attribute_list: An iterable of ir_data.Attribute. - back_end: The qualifier for attributes to check, or None. - attribute_specs: A dict of attribute names to _Attribute structures - specifying the allowed attributes. - context_name: A name for the context of these attributes, such as "struct - 'Foo'" or "module 'm.emb'". 
Used in error messages. - module_source_file: The value of module.source_file_name from the module - containing 'attribute_list'. Used in error messages. + attribute_list: An iterable of ir_data.Attribute. + types: A map of attribute types to validators. + back_end: The qualifier for attributes to check, or None. + attribute_specs: A dict of attribute names to _Attribute structures + specifying the allowed attributes. + context_name: A name for the context of these attributes, such as + "struct 'Foo'" or "module 'm.emb'". Used in error messages. + module_source_file: The value of module.source_file_name from the module + containing 'attribute_list'. Used in error messages. Returns: - A list of lists of error.Errors. An empty list indicates no errors were - found. + A list of lists of error.Errors. An empty list indicates no errors were + found. """ if attribute_specs is None: attribute_specs = [] @@ -395,17 +397,17 @@ def _check_attributes( def gather_default_attributes(obj, defaults): - """Gathers default attributes for an IR object + """Gathers default attributes for an IR object. - This is designed to be able to be used as-is as an incidental action in an IR - traversal to accumulate defaults for child nodes. + This is designed to be able to be used as-is as an incidental action in an + IR traversal to accumulate defaults for child nodes. Arguments: - defaults: A dict of `{ "defaults": { attr.name.text: attr } }` + defaults: A dict of `{ "defaults": { attr.name.text: attr } }` Returns: - A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults - provided by `obj` added/overridden. + A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults + provided by `obj` added/overridden. 
""" defaults = defaults.copy() for attr in obj.attribute: diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py index af8c2f7..fda124e 100644 --- a/compiler/util/ir_data.py +++ b/compiler/util/ir_data.py @@ -110,7 +110,10 @@ def HasField(self, name): # pylint:disable=invalid-name def WhichOneof(self, oneof_name): # pylint:disable=invalid-name """Indicates which field has been set for the oneof value. - Returns None if no field has been set. + Args: + oneof_name: the name of the oneof construct to test. + + Returns: the field name, or None if no field has been set. """ for field_name, oneof in self.field_specs.oneof_mappings: if oneof == oneof_name and self.HasField(field_name): @@ -207,7 +210,7 @@ class NumericConstant(Message): class FunctionMapping(int, enum.Enum): - """Enum of supported function types""" + """Enum of supported function types.""" UNKNOWN = 0 ADDITION = 1 @@ -823,8 +826,9 @@ class RuntimeParameter(Message): class AddressableUnit(int, enum.Enum): - """The "addressable unit" is the size of the smallest unit that can be read + """The 'atom size' for a structure. + The "addressable unit" is the size of the smallest unit that can be read from the backing store that this type expects. 
For `struct`s, this is BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends on the specific type diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py index af6fbda..248518b 100644 --- a/compiler/util/ir_data_fields.py +++ b/compiler/util/ir_data_fields.py @@ -76,7 +76,7 @@ def _is_ir_dataclass(obj): class CopyValuesList(list[CopyValuesListT]): - """A list that makes copies of any value that is inserted""" + """A list that makes copies of any value that is inserted.""" def __init__( self, value_type: CopyValuesListT, iterable: Optional[Iterable[Any]] = None @@ -96,7 +96,7 @@ def extend(self, iterable: Iterable) -> None: return super().extend([self._copy(i) for i in iterable]) def shallow_copy(self, iterable: Iterable) -> None: - """Explicitly performs a shallow copy of the provided list""" + """Explicitly performs a shallow copy of the provided list.""" return super().extend(iterable) def append(self, obj: Any) -> None: @@ -107,15 +107,13 @@ def insert(self, index: SupportsIndex, obj: Any) -> None: class TemporaryCopyValuesList(NamedTuple): - """Class used to temporarily hold a CopyValuesList while copying and - constructing an IR dataclass. - """ + """Holder for a CopyValuesList while copying/constructing an IR dataclass.""" temp_list: CopyValuesList class FieldContainer(enum.Enum): - """Indicates a fields container type""" + """Indicates a fields container type.""" NONE = 0 OPTIONAL = 1 @@ -125,8 +123,8 @@ class FieldContainer(enum.Enum): class FieldSpec(NamedTuple): """Indicates the container and type of a field. - `FieldSpec` objects are accessed millions of times during runs so we cache as - many operations as possible. + `FieldSpec` objects are accessed millions of times during runs so we cache + as many operations as possible. 
- `is_dataclass`: `dataclasses.is_dataclass(data_type)` - `is_sequence`: `container is FieldContainer.LIST` - `is_enum`: `issubclass(data_type, enum.Enum)` @@ -162,7 +160,7 @@ def make_field_spec( def build_default(field_spec: FieldSpec): - """Builds a default instance of the given field""" + """Builds a default instance of the given field.""" if field_spec.is_sequence: return CopyValuesList(field_spec.data_type) if field_spec.is_enum: @@ -216,11 +214,20 @@ def get_specs(cls, data_class): def cache_message_specs(mod, cls): - """Adds a cached `field_specs` attribute to IR dataclasses in `mod` - excluding the given base `cls`. + """Adds `field_specs` to `mod`, excluding `cls`. + + Adds a cached `field_specs` attribute to IR dataclasses in module `mod` + excluding the given base class `cls`. This needs to be done after the dataclass decorators run and create the wrapped classes. + + Arguments: + mod: The module to process. + cls: The base class to exclude. + + Returns: + None """ for data_class in all_ir_classes(mod): if data_class is not cls: @@ -228,7 +235,7 @@ def cache_message_specs(mod, cls): def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]: - """Gets the IR data field names and types for the given IR data class""" + """Gets the IR data field names and types for the given IR data class.""" # Get the dataclass fields class_fields = dataclasses.fields(cast(Any, cls)) @@ -287,6 +294,12 @@ def field_specs(obj: Union[IrDataT, type[IrDataT]]) -> Mapping[str, FieldSpec]: """Retrieves the fields specs for the the give data type. The results of this method are cached to reduce lookup overhead. + + Arguments: + obj: Either an IR dataclass type, or an instance of such a type. + + Returns: + The field specs for `obj`. """ cls = obj if isinstance(obj, type) else type(obj) if cls is type(None): @@ -301,8 +314,11 @@ def fields_and_values( """Retrieves the fields and their values for a given IR data class. 
Args: - ir: The IR data class or a read-only wrapper of an IR data class. - value_filt: Optional filter used to exclude values. + ir: The IR data class or a read-only wrapper of an IR data class. + value_filt: Optional filter used to exclude values. + + Returns: + None """ set_fields: list[Tuple[FieldSpec, Any]] = [] specs: FilteredIrFieldSpecs = ir.field_specs @@ -329,6 +345,7 @@ def fields_and_values( # 5. None checks are only done in `copy()`, `_copy_set_fields` only # references `_copy()` to avoid this step. def _copy_set_fields(ir: IrDataT): + """Deep copies fields from IR node `ir`.""" values: MutableMapping[str, Any] = {} specs: FilteredIrFieldSpecs = ir.field_specs @@ -355,7 +372,7 @@ def _copy(ir: IrDataT) -> IrDataT: def copy(ir: IrDataT) -> Optional[IrDataT]: - """Creates a copy of the given IR data class""" + """Creates a copy of the given IR data class.""" if not ir: return None return _copy(ir) @@ -413,14 +430,14 @@ def __set__(self, obj, value): def oneof_field(name: str): - """Alternative for `datclasses.field` that sets up a oneof variable""" + """Alternative for `datclasses.field` that sets up a oneof variable.""" return dataclasses.field( # pylint:disable=invalid-field-call default=OneOfField(name), metadata={"oneof": name}, init=True ) def str_field(): - """Helper used to define a defaulted str field""" + """Helper used to define a defaulted str field.""" return dataclasses.field(default_factory=str) # pylint:disable=invalid-field-call @@ -436,7 +453,10 @@ class Foo: ``` Args: - cls_or_fn: The class type or a function that resolves to the class type. + cls_or_fn: The class type or a function that resolves to the class type. + + Returns: + A field with a `default_factory` that produces an appropriate list. 
""" def list_factory(c): diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py index 344dd13..103c1ad 100644 --- a/compiler/util/ir_data_fields_test.py +++ b/compiler/util/ir_data_fields_test.py @@ -34,12 +34,12 @@ class TestEnum(enum.Enum): @dataclasses.dataclass class Opaque(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" @dataclasses.dataclass class ClassWithUnion(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" opaque: Optional[Opaque] = ir_data_fields.oneof_field("type") integer: Optional[int] = ir_data_fields.oneof_field("type") @@ -50,7 +50,7 @@ class ClassWithUnion(ir_data.Message): @dataclasses.dataclass class ClassWithTwoUnions(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1") integer: Optional[int] = ir_data_fields.oneof_field("type_1") @@ -62,7 +62,7 @@ class ClassWithTwoUnions(ir_data.Message): @dataclasses.dataclass class NestedClass(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" one_union_class: Optional[ClassWithUnion] = None two_union_class: Optional[ClassWithTwoUnions] = None @@ -78,7 +78,7 @@ class ListCopyTestClass(ir_data.Message): @dataclasses.dataclass class OneofFieldTest(ir_data.Message): - """Basic test class for oneof fields""" + """Basic test class for oneof fields.""" int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1") int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1") @@ -86,7 +86,7 @@ class OneofFieldTest(ir_data.Message): class OneOfTest(unittest.TestCase): - """Tests for the the various oneof field helpers""" + """Tests for the the various oneof field helpers.""" def test_field_attribute(self): """Test the `oneof_field` helper.""" @@ -97,21 +97,21 @@ def 
test_field_attribute(self): self.assertEqual(test_field.metadata.get("oneof"), "type_1") def test_init_default(self): - """Test creating an instance with default fields""" + """Test creating an instance with default fields.""" one_of_field_test = OneofFieldTest() self.assertIsNone(one_of_field_test.int_field_1) self.assertIsNone(one_of_field_test.int_field_2) self.assertTrue(one_of_field_test.normal_field) def test_init(self): - """Test creating an instance with non-default fields""" + """Test creating an instance with non-default fields.""" one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) self.assertEqual(one_of_field_test.int_field_1, 10) self.assertIsNone(one_of_field_test.int_field_2) self.assertFalse(one_of_field_test.normal_field) def test_set_oneof_field(self): - """Tests setting oneof fields causes others in the group to be unset""" + """Tests setting oneof fields causes others in the group to be unset.""" one_of_field_test = OneofFieldTest() one_of_field_test.int_field_1 = 10 self.assertEqual(one_of_field_test.int_field_1, 10) @@ -128,8 +128,8 @@ def test_set_oneof_field(self): self.assertIsNone(one_of_field_test.int_field_1) self.assertEqual(one_of_field_test.int_field_2, 20) - # Now create a new instance and make sure changes to it are not reflected - # on the original object. + # Now create a new instance and make sure changes to it are not + # reflected on the original object. 
one_of_field_test_2 = OneofFieldTest() one_of_field_test_2.int_field_1 = 1000 self.assertEqual(one_of_field_test_2.int_field_1, 1000) @@ -138,7 +138,7 @@ def test_set_oneof_field(self): self.assertEqual(one_of_field_test.int_field_2, 20) def test_set_to_none(self): - """Tests explicitly setting a oneof field to None""" + """Tests explicitly setting a oneof field to None.""" one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) self.assertEqual(one_of_field_test.int_field_1, 10) self.assertIsNone(one_of_field_test.int_field_2) @@ -163,7 +163,7 @@ def test_set_to_none(self): self.assertFalse(one_of_field_test.normal_field) def test_oneof_specs(self): - """Tests the `oneof_field_specs` filter""" + """Tests the `oneof_field_specs` filter.""" expected = { "int_field_1": ir_data_fields.make_field_spec( "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1" @@ -178,7 +178,7 @@ def test_oneof_specs(self): self.assertDictEqual(actual, expected) def test_oneof_mappings(self): - """Tests the `oneof_mappings` function""" + """Tests the `oneof_mappings` function.""" expected = (("int_field_1", "type_1"), ("int_field_2", "type_1")) actual = ir_data_fields.IrDataclassSpecs.get_specs( OneofFieldTest @@ -187,7 +187,7 @@ def test_oneof_mappings(self): class IrDataFieldsTest(unittest.TestCase): - """Tests misc methods in ir_data_fields""" + """Tests misc methods in ir_data_fields.""" def assertEmpty(self, obj): self.assertEqual(len(obj), 0, msg=f"{obj} is not empty.") @@ -196,7 +196,7 @@ def assertLen(self, obj, length): self.assertEqual(len(obj), length, msg=f"{obj} has length {len(obj)}.") def test_copy(self): - """Tests copying a data class works as expected""" + """Tests copying a data class works as expected.""" union = ClassWithTwoUnions( opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3] ) @@ -210,7 +210,7 @@ def test_copy(self): self.assertIsNone(empty_copy) def test_copy_values_list(self): - """Tests that CopyValuesList 
copies values""" + """Tests that CopyValuesList copies values.""" data_list = ir_data_fields.CopyValuesList(ListCopyTestClass) self.assertEmpty(data_list) @@ -222,7 +222,7 @@ def test_copy_values_list(self): self.assertEqual(i, list_test) def test_list_param_is_copied(self): - """Test that lists passed to constructors are converted to CopyValuesList""" + """Test that lists passed to constructors are converted to CopyValuesList.""" seq_field = [5, 6, 7] list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field) self.assertLen(list_test.seq_field, len(seq_field)) diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py index 8347625..f880231 100644 --- a/compiler/util/ir_data_utils.py +++ b/compiler/util/ir_data_utils.py @@ -91,7 +91,7 @@ def field_specs(ir: Union[MessageT, type[MessageT]]): class IrDataSerializer: - """Provides methods for serializing IR data objects""" + """Provides methods for serializing IR data objects.""" def __init__(self, ir: MessageT): assert ir is not None @@ -102,6 +102,7 @@ def _to_dict( ir: MessageT, field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]], ) -> MutableMapping[str, Any]: + """Translates the IR to a standard Python `dict`.""" assert ir is not None values: MutableMapping[str, Any] = {} for spec, value in field_func(ir): @@ -117,12 +118,12 @@ def to_dict(self, exclude_none: bool = False): """Converts the IR data class to a dictionary.""" def non_empty(ir): - return fields_and_values( + return _fields_and_values( ir, lambda v: v is not None and (not isinstance(v, list) or len(v)) ) def all_fields(ir): - return fields_and_values(ir) + return _fields_and_values(ir) # It's tempting to use `dataclasses.asdict` here, but that does a deep # copy which is overkill for the current usage; mainly as an intermediary @@ -130,17 +131,17 @@ def all_fields(ir): return self._to_dict(self.ir, non_empty if exclude_none else all_fields) def to_json(self, *args, **kwargs): - """Converts the IR 
data class to a JSON string""" + """Converts the IR data class to a JSON string.""" return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs) @staticmethod def from_json(data_cls, data): - """Constructs an IR data class from the given JSON string""" + """Constructs an IR data class from the given JSON string.""" as_dict = json.loads(data) return IrDataSerializer.from_dict(data_cls, as_dict) def copy_from_dict(self, data): - """Deserializes the data and overwrites the IR data class with it""" + """Deserializes the data and overwrites the IR data class with it.""" cls = type(self.ir) data_copy = IrDataSerializer.from_dict(cls, data) for k in field_specs(cls): @@ -148,6 +149,7 @@ def copy_from_dict(self, data): @staticmethod def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum: + """Converts `val` to an instance of `enum_cls`.""" if isinstance(val, str): return getattr(enum_cls, val) return enum_cls(val) @@ -158,6 +160,7 @@ def _enum_type_hook(enum_cls: type[enum.Enum]): @staticmethod def _from_dict(data_cls: type[MessageT], data): + """Translates the given `data` dict to an instance of `data_cls`.""" class_fields: MutableMapping[str, Any] = {} for name, spec in ir_data_fields.field_specs(data_cls).items(): if (value := data.get(name)) is not None: @@ -188,12 +191,12 @@ def _from_dict(data_cls: type[MessageT], data): @staticmethod def from_dict(data_cls: type[MessageT], data): - """Creates a new IR data instance from a serialized dict""" + """Creates a new IR data instance from a serialized dict.""" return IrDataSerializer._from_dict(data_cls, data) class _IrDataSequenceBuilder(MutableSequence[MessageT]): - """Wrapper for a list of IR elements + """Wrapper for a list of IR elements. Simply wraps the returned values during indexed access and iteration with IrDataBuilders. 
@@ -236,7 +239,7 @@ def extend(self, values): class _IrDataBuilder(Generic[MessageT]): - """Wrapper for an IR element""" + """Wrapper for an IR element.""" def __init__(self, ir: MessageT) -> None: assert ir is not None @@ -254,10 +257,15 @@ def __setattr__(self, __name: str, __value: Any) -> None: def __getattribute__(self, name: str) -> Any: """Hook for `getattr` that handles adding missing fields. - If the field is missing inserts it, and then returns either the raw value - for basic types - or a new IrBuilder wrapping the field to handle the next field access in a - longer chain. + If the field is missing inserts it, and then returns either the raw + value for basic types or a new IrBuilder wrapping the field to handle + the next field access in a longer chain. + + Arguments: + name: the name of the attribute to set/retrieve + + Returns: + The value of the attribute `name`. """ # Check if getting one of the builder attributes @@ -294,12 +302,12 @@ def __getattribute__(self, name: str) -> Any: return obj def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name - """Updates the fields of this class with values set in the template""" + """Updates the fields of this class with values set in the template.""" update(cast(type[MessageT], self), template) def builder(target: MessageT) -> MessageT: - """Create a wrapper around the target to help build an IR Data structure""" + """Create a wrapper around the target to help build an IR Data structure.""" # Check if the target is already a builder. 
if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)): return target @@ -314,7 +322,7 @@ def builder(target: MessageT) -> MessageT: def _field_checker_from_spec(spec: ir_data_fields.FieldSpec): - """Helper that builds an FieldChecker that pretends to be an IR class""" + """Helper that builds a FieldChecker that pretends to be an IR class.""" if spec.is_sequence: return [] if spec.is_dataclass: @@ -323,13 +331,14 @@ def _field_checker_from_spec(spec: ir_data_fields.FieldSpec): def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type[Any]: + """Returns the Python type of the given field.""" if isinstance(ir_or_spec, ir_data_fields.FieldSpec): return ir_or_spec.data_type return type(ir_or_spec) class _ReadOnlyFieldChecker: - """Class used the chain calls to fields that aren't set""" + """Class used to chain calls to fields that aren't set.""" def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None: self.ir_or_spec = ir_or_spec @@ -382,8 +391,7 @@ def __ne__(self, other): def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT: - """Builds a read-only wrapper that can be used to check chains of possibly - unset fields. + """Builds a wrapper that can be used to read chains of possibly unset fields. This wrapper explicitly does not alter the wrapped object and is only intended for reading contents. 
@@ -391,18 +399,36 @@ def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT: For example, a `reader` lets you do: ``` def get_function_name_end_column(function: ir_data.Function): - return reader(function).function_name.source_location.end.column + return reader(function).function_name.source_location.end.column ``` Instead of: ``` def get_function_name_end_column(function: ir_data.Function): - if function.function_name: - if function.function_name.source_location: - if function.function_name.source_location.end: - return function.function_name.source_location.end.column - return 0 + if function.function_name: + if function.function_name.source_location: + if function.function_name.source_location.end: + return function.function_name.source_location.end.column + return 0 ``` + + Arguments: + obj: The IR node to wrap. + + Returns: + An object whose attributes return either: + + The value of `obj.attr` if `attr` is an atomic type and is set on + `obj`. + + A default value for `obj.attr` if `obj.attr` is not set, but is of an + atomic type. + + A read-only wrapper around `obj.attr` if `obj.attr` is set and is an IR + node type. + + A read-only wrapper around an empty IR node object if `obj.attr` is not + set, and is of an IR node type. """ # Create a read-only wrapper if it's not already one. if not isinstance(obj, _ReadOnlyFieldChecker): @@ -426,15 +452,19 @@ def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT: return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper) -def fields_and_values( +def _fields_and_values( ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker], value_filt: Optional[Callable[[Any], bool]] = None, ) -> list[Tuple[ir_data_fields.FieldSpec, Any]]: """Retrieves the fields and their values for a given IR data class. Args: - ir: The IR data class or a read-only wrapper of an IR data class. - value_filt: Optional filter used to exclude values. + ir_wrapper: The IR data class or a read-only wrapper of an IR data class. + value_filt: Optional filter used to exclude values. 
+ + Returns: + Fields and their values for the IR held by `ir_wrapper`, optionally + filtered by `value_filt`. """ if (ir := _extract_ir(ir_wrapper)) is None: return [] @@ -443,15 +473,22 @@ def fields_and_values( def get_set_fields(ir: MessageT): - """Retrieves the field spec and value of fields that are set in the given IR data class. + """Retrieves the field specs and values of fields that are set in `ir`. A value is considered "set" if it is not None. + + Arguments: + ir: The IR node to operate on. + + Returns: + The field specs and values of fields that are set in the given IR data + class. """ - return fields_and_values(ir, lambda v: v is not None) + return _fields_and_values(ir, lambda v: v is not None) def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]: - """Creates a copy of the given IR data class""" + """Creates a copy of the given IR data class.""" if (ir := _extract_ir(ir_wrapper)) is None: return None ir_copy = ir_data_fields.copy(ir) diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py index c6e0435..954fc33 100644 --- a/compiler/util/ir_data_utils_test.py +++ b/compiler/util/ir_data_utils_test.py @@ -35,12 +35,12 @@ class TestEnum(enum.Enum): @dataclasses.dataclass class Opaque(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" @dataclasses.dataclass class ClassWithUnion(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" opaque: Optional[Opaque] = ir_data_fields.oneof_field("type") integer: Optional[int] = ir_data_fields.oneof_field("type") @@ -51,7 +51,7 @@ class ClassWithUnion(ir_data.Message): @dataclasses.dataclass class ClassWithTwoUnions(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers.""" opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1") integer: Optional[int] = ir_data_fields.oneof_field("type_1") @@ -65,7 +65,7 @@ 
class IrDataUtilsTest(unittest.TestCase): """Tests for the miscellaneous utility functions in ir_data_utils.py.""" def test_field_specs(self): - """Tests the `field_specs` method""" + """Tests the `field_specs` method.""" fields = ir_data_utils.field_specs(ir_data.TypeDefinition) self.assertIsNotNone(fields) expected_fields = ( @@ -132,7 +132,7 @@ def test_field_specs(self): self.assertEqual(fields["base_type"], expected_field) def test_is_sequence(self): - """Tests for the `FieldSpec.is_sequence` helper""" + """Tests for the `FieldSpec.is_sequence` helper.""" type_def = ir_data.TypeDefinition( attribute=[ ir_data.Attribute( @@ -151,7 +151,7 @@ def test_is_sequence(self): self.assertFalse(fields["is_default"].is_sequence) def test_is_dataclass(self): - """Tests FieldSpec.is_dataclass against ir_data""" + """Tests FieldSpec.is_dataclass against ir_data.""" type_def = ir_data.TypeDefinition( attribute=[ ir_data.Attribute( @@ -173,7 +173,7 @@ def test_is_dataclass(self): self.assertFalse(fields["fields_in_dependency_order"].is_dataclass) def test_get_set_fields(self): - """Tests that get set fields works""" + """Tests that get set fields works.""" type_def = ir_data.TypeDefinition( attribute=[ ir_data.Attribute( @@ -196,7 +196,7 @@ def test_get_set_fields(self): self.assertSetEqual(found_fields, expected_fields) def test_copy(self): - """Tests the `copy` helper""" + """Tests the `copy` helper.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), @@ -219,7 +219,7 @@ def test_copy(self): self.assertIsNot(type_def.attribute, type_def_copy.attribute) def test_update(self): - """Tests the `update` helper""" + """Tests the `update` helper.""" attribute_template = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), @@ -236,7 +236,7 @@ def test_update(self): class IrDataBuilderTest(unittest.TestCase): - """Tests for IrDataBuilder""" + 
"""Tests for IrDataBuilder.""" def assertEmpty(self, obj): self.assertEqual(len(obj), 0, msg=f"{obj} is not empty.") @@ -245,7 +245,7 @@ def assertLen(self, obj, length): self.assertEqual(len(obj), length, msg=f"{obj} has length {len(obj)}.") def test_ir_data_builder(self): - """Tests that basic builder chains work""" + """Tests that basic builder chains work.""" # We start with an empty type type_def = ir_data.TypeDefinition() self.assertFalse(type_def.HasField("name")) @@ -264,7 +264,7 @@ def test_ir_data_builder(self): self.assertEqual(type_def.name.name.text, "phil") def test_ir_data_builder_bad_field(self): - """Tests accessing an undefined field name fails""" + """Tests accessing an undefined field name fails.""" type_def = ir_data.TypeDefinition() builder = ir_data_utils.builder(type_def) self.assertRaises(AttributeError, lambda: builder.foo) @@ -272,7 +272,7 @@ def test_ir_data_builder_bad_field(self): self.assertRaises(AttributeError, getattr, type_def, "foo") def test_ir_data_builder_sequence(self): - """Tests that sequences are properly wrapped""" + """Tests that sequences are properly wrapped.""" # We start with an empty type type_def = ir_data.TypeDefinition() self.assertTrue(type_def.HasField("attribute")) @@ -374,7 +374,7 @@ def test_copy_from_list(self): ) def test_ir_data_builder_sequence_scalar(self): - """Tests that sequences of scalars function properly""" + """Tests that sequences of scalars function properly.""" # We start with an empty type structure = ir_data.Structure() @@ -410,10 +410,10 @@ def test_ir_data_builder_oneof(self): class IrDataSerializerTest(unittest.TestCase): - """Tests for IrDataSerializer""" + """Tests for IrDataSerializer.""" def test_ir_data_serializer_to_dict(self): - """Tests serialization with `IrDataSerializer.to_dict` with default settings""" + """Tests serialization with `IrDataSerializer.to_dict` with default settings.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), 
 name=ir_data.Word(text="phil"), @@ -444,7 +444,7 @@ def test_ir_data_serializer_to_dict(self): self.assertDictEqual(raw_dict, expected) def test_ir_data_serializer_to_dict_exclude_none(self): - """Tests serialization with `IrDataSerializer.to_dict` when excluding None values""" + """Tests serialization with `IrDataSerializer.to_dict` when excluding None values.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), ) @@ -455,7 +455,7 @@ def test_ir_data_serializer_to_dict_exclude_none(self): self.assertDictEqual(raw_dict, expected) def test_ir_data_serializer_to_dict_enum(self): - """Tests that serialization of `enum.Enum` values works properly""" + """Tests that serialization of `enum.Enum` values works properly.""" type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) serializer = ir_data_utils.IrDataSerializer(type_def) raw_dict = serializer.to_dict(exclude_none=True) @@ -463,7 +463,7 @@ def test_ir_data_serializer_to_dict_enum(self): self.assertDictEqual(raw_dict, expected) def test_ir_data_serializer_from_dict(self): - """Tests deserializing IR data from a serialized dict""" + """Tests deserializing IR data from a serialized dict.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), ) @@ -474,7 +474,7 @@ def test_ir_data_serializer_from_dict(self): self.assertEqual(attribute, new_attribute) def test_ir_data_serializer_from_dict_enum(self): - """Tests that deserializing `enum.Enum` values works properly""" + """Tests that deserializing `enum.Enum` values works properly.""" type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) serializer = ir_data_utils.IrDataSerializer(type_def) @@ -483,7 +483,7 @@ def test_ir_data_serializer_from_dict_enum(self): self.assertEqual(type_def, new_type_def) def test_ir_data_serializer_from_dict_enum_is_str(self): - """Tests that 
deserializing `enum.Enum` values works properly when string constant is used""" + """Tests that deserializing `enum.Enum` values works properly when string constant is used.""" type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) raw_dict = {"addressable_unit": "BYTE"} serializer = ir_data_utils.IrDataSerializer(type_def) @@ -491,7 +491,7 @@ def test_ir_data_serializer_from_dict_enum_is_str(self): self.assertEqual(type_def, new_type_def) def test_ir_data_serializer_from_dict_exclude_none(self): - """Tests that deserializing from a dict that excluded None values works properly""" + """Tests that deserializing from a dict that excluded None values works properly.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), @@ -558,7 +558,7 @@ def test_from_dict_list(self): self.assertIsNotNone(func) def test_ir_data_serializer_copy_from_dict(self): - """Tests that updating an IR data struct from a dict works properly""" + """Tests that updating an IR data struct from a dict works properly.""" attribute = ir_data.Attribute( value=ir_data.AttributeValue(expression=ir_data.Expression()), name=ir_data.Word(text="phil"), @@ -573,10 +573,10 @@ def test_ir_data_serializer_copy_from_dict(self): class ReadOnlyFieldCheckerTest(unittest.TestCase): - """Tests the ReadOnlyFieldChecker""" + """Tests the ReadOnlyFieldChecker.""" def test_basic_wrapper(self): - """Tests basic field checker actions""" + """Tests basic field checker actions.""" union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10) field_checker = ir_data_utils.reader(union) @@ -597,7 +597,7 @@ def test_basic_wrapper(self): self.assertTrue(field_checker.HasField("non_union_field")) def test_construct_from_field_checker(self): - """Tests that constructing from another field checker works""" + """Tests that constructing from another field checker works.""" union = ClassWithTwoUnions(opaque=Opaque(), 
boolean=True, non_union_field=10) field_checker_orig = ir_data_utils.reader(union) field_checker = ir_data_utils.reader(field_checker_orig) @@ -621,7 +621,7 @@ def test_construct_from_field_checker(self): self.assertTrue(field_checker.HasField("non_union_field")) def test_read_only(self) -> None: - """Tests that the read only wrapper really is read only""" + """Tests that the read only wrapper really is read only.""" union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10) field_checker = ir_data_utils.reader(union) diff --git a/compiler/util/name_conversion.py b/compiler/util/name_conversion.py index b13d667..b4303fb 100644 --- a/compiler/util/name_conversion.py +++ b/compiler/util/name_conversion.py @@ -63,10 +63,19 @@ def snake_to_k_camel(name): def convert_case(case_from, case_to, value): """Converts cases based on runtime case values. - Note: Cases can be strings or enum values.""" + Note: Cases can be strings or enum values. + + Arguments: + case_from: the name of the original case + case_to: the name of the desired case + value: the value to convert + + Returns: + `value` converted from `case_from` to `case_to`. + """ return _case_conversions[case_from, case_to](value) def is_case_conversion_supported(case_from, case_to): - """Determines if a case conversion would be supported""" + """Determines if a case conversion would be supported.""" return (case_from, case_to) in _case_conversions