diff --git a/src/mccode_antlr/translators/c_listener.py b/src/mccode_antlr/translators/c_listener.py index fee8616..38e7d26 100644 --- a/src/mccode_antlr/translators/c_listener.py +++ b/src/mccode_antlr/translators/c_listener.py @@ -91,8 +91,8 @@ def __str__(self): def __hash__(self): return hash(str(self)) - def as_struct_member(self, max_array_length): - dec = self.declare.as_struct_member(max_array_length=max_array_length) + def as_struct_member(self, dims: int, max_array_length): + dec = self.declare.as_struct_member(dims=dims, max_array_length=max_array_length) return self.string(dec) @@ -101,10 +101,14 @@ class CDeclarator: declare: str | CFuncPointer pointer: str | None = None extensions: list[str] = field(default_factory=list) - elements: int | str | None = None + elements: tuple[int,...] | tuple[str,...] | None = None dtype: str | None = None init: str | None = None + def __post_init__(self): + if self.elements is not None and not isinstance(self.elements, tuple): + self.elements = tuple(self.elements) + @property def is_pointer(self) -> bool: return self.pointer is not None and len(self.pointer.strip(' ')) > 0 @@ -164,17 +168,47 @@ def __str__(self): def __hash__(self): return hash(str(self)) - def as_struct_member(self, max_array_length: int = 16384): + def as_struct_member(self, dims: int = 0, max_array_length: int = 16384): if self.init: - max_array_length = min(len(self.init.split(',')), max_array_length) - no = self.elements if self.elements else max_array_length + max_array_length = extract_initializer_size(self.init) + dims = len(self.elements) if self.elements else 0 + mal = max_array_length if isinstance(max_array_length, tuple) else (max_array_length,) + if len(mal) < dims: + mal *= dims // len(mal) + no = self.elements if self.elements else mal + if not isinstance(no, tuple): + raise RuntimeError(f'{no} should be a tuple but is a {type(no)}') if isinstance(self.declare, CFuncPointer): - return self.string(self.declare.as_struct_member(max_array_length=no)) + return self.string(self.declare.as_struct_member(dims=dims, max_array_length=no)) elif self.elements is not None: - return f'{self}[{no}]' + if 0 in no: + # Is this a bad idea? + no = tuple(m if x == 0 else x for x, m in zip(no, mal)) + return f"{self}{''.join(f'[{n}]' for n in no)}" return str(self) +def extract_initializer_size(init: str) -> tuple[int,...]: + """Assuming that the provided initializer string is for an array, possibly + of more statically sized arrays, and that it has been written correctly, + extract the size of each nested dimension + + >>> extract_initializer_size("{{0, 1, 2}, {3, 4, 5}}") + (2, 3) + + >>> extract_initializer_size('{{{0}, {1}}, {{2}, {3}}, {{4}, {5}}, {{6}, {7}}}') + (4, 2, 1) + """ + import re + size = [] + inner = re.compile('{([^{}]*)}') + x = init + while len(y:=inner.findall(x)): + size.append(len(y[0].split(','))) + x = inner.sub('_', x) + return tuple(reversed(size)) + + class DeclaresCVisitor(CVisitor): def __init__(self, typedefs: list | None = None, verbose: bool = False): self.verbose = verbose @@ -294,15 +328,18 @@ def visitDeclarator(self, ctx:CParser.DeclaratorContext): # dec.declare.elements = None pass elif all(x in dec for x in ('[', ']')): - if dec.count('[') > 1 or dec.count(']') > 1: - raise RuntimeError('No idea how to handle multi-level arrays') - dec, num_post = dec.split('[', 1) - num, _ = num_post.split(']', 1) + dec, post = dec.split('[', 1) + nums = tuple() + while ']' in post: + num, post = post.split(']', 1) + nums += (num,) + if '[' in post: + _, post = post.split('[', 1) try: - elements = int(num) if len(num) else 0 + elements = tuple(int(n) if len(n) else 0 for n in nums) except ValueError as er: - logger.info(f"Could not convert an integer from {num} due to {er}") - elements = num + logger.info(f"Could not convert an integer from {nums} due to {er}") + elements = nums return CDeclarator(pointer=ptr, declare=dec, extensions=extensions, elements=elements) def visitPointer(self, ctx:CParser.PointerContext): diff --git a/tests/test_c_type_declaration.py b/tests/test_c_type_declaration.py index 4e2962f..2589a98 100644 --- a/tests/test_c_type_declaration.py +++ b/tests/test_c_type_declaration.py @@ -84,7 +84,7 @@ def test_assignments(): expected = [ CDeclarator(dtype='int', declare='blah', init='1'), CDeclarator(dtype='double', declare='yarg'), - CDeclarator(dtype='char', declare='mmmm', init='"0123456789"', elements=11), + CDeclarator(dtype='char', declare='mmmm', init='"0123456789"', elements=(11,)), ] for x in expected: assert x in variables @@ -121,7 +121,7 @@ def test_typedef_declaration(): expected = [ CDeclarator(dtype='blah', declare='really_a_double', init='1.0f'), CDeclarator(dtype='blah', declare='double_ptr', pointer='*', init='NULL'), - CDeclarator(dtype='blah', declare='double_array', elements=10) + CDeclarator(dtype='blah', declare='double_array', elements=(10,)) ] assert all(x in variables for x in expected) @@ -178,14 +178,14 @@ def test_function_pointer_declaration(): CDeclarator( dtype='int', declare=CFuncPointer( - declare=CDeclarator(pointer='*', declare='fun_ptr_ar3', elements=3), + declare=CDeclarator(pointer='*', declare='fun_ptr_ar3', elements=(3,)), args='int, int', ), ), CDeclarator( dtype='int', declare=CFuncPointer( - declare=CDeclarator(pointer='*', declare='fun_ptr_arr', elements=0), + declare=CDeclarator(pointer='*', declare='fun_ptr_arr', elements=(0,)), args='int, int', ), init='{add, sub, mul}', @@ -198,5 +198,35 @@ def test_function_pointer_declaration(): 'int (* fun_ptr_ar3[3])(int, int)', 'int (* fun_ptr_arr[3])(int, int)', ] + for x, y in zip(members, variables): + assert x == y.as_struct_member() + +def test_multi_level_static_array_types(): + big_table=dedent("""{ + { 2.087063, 0.23391E+00 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 1.80745 , 0.17544E+00 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 1.27806 , 0.87718E-01 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 1.089933, 0.63795E-01 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 0.903725, 0.43859E-01 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 0.829315, 0.36934E-01 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00}, + { 0.808316, 0.35087E-01 ,7.485, 2.094, 0.55, 11.81, 63.54, 315, 12.00, 0.30000E+00} + }""") # .replace('\n','') + block = dedent(f"""\ + double mmsa[2][3] = {{{{0, 1, 2}}, {{3, 4, 5}}}}; + int mmi[][][] = {{{{{{0}}, {{1}}}}, {{{{2}}, {{3}}}}, {{{{4}}, {{5}}}}, {{{{6}}, {{7}}}}}}; + double big_table[7][10] = {big_table}; + """) + variables, types = extract(block) + expected = [ + CDeclarator(dtype='double', declare='mmsa', elements=(2, 3), + init='{{0, 1, 2}, {3, 4, 5}}',), + CDeclarator(dtype='int', declare='mmi', elements=(0, 0, 0), + init='{{{0}, {1}}, {{2}, {3}}, {{4}, {5}}, {{6}, {7}}}'), + CDeclarator(dtype='double', declare='big_table', elements=(7, 10), + init=big_table) + ] + for x, y in zip(expected, variables): + assert x == y + members = ['double mmsa[2][3]', 'int mmi[4][2][1]', 'double big_table[7][10]'] for x, y in zip(members, variables): assert x == y.as_struct_member() \ No newline at end of file