Skip to content

Commit

Permalink
headers: Fix Function Pointer as parameter arguments
Browse files Browse the repository at this point in the history
Use libclang to detect kinds of parameters
  • Loading branch information
bwrsandman committed Sep 28, 2024
1 parent 8d2f689 commit 0f816eb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 13 deletions.
5 changes: 3 additions & 2 deletions scripts/headers/bw1_decomp_gen/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


def clean_up_type(typename):
if type(typename) is csnake.FuncPtr:
return typename
type_part, pointer_part, after_pointer_part = typename.partition("*")
type_part = type_part.rstrip()
type_part = TYPE_SUBSTITUTIONS.get(type_part, type_part)
Expand Down Expand Up @@ -110,8 +112,7 @@ def from_json(cls, decl: dict) -> "FuncPtr":

def to_csnake(self) -> CSnakeFuncPtr:
conv = self.call_type
params = [[l or f"param_{i}", a] for i, (l, a) in enumerate(
zip(self.arg_labels, self.args))]
params = [[l or f"param_{i}", a] for i, (l, a) in enumerate(zip(self.arg_labels, self.args))]
if conv == "__thiscall":
params[0][0] = "this"
if len(params) > 1:
Expand Down
71 changes: 67 additions & 4 deletions scripts/headers/bw1_decomp_gen/generate_headers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import time
import shutil
import sys
from clang import cindex
from json import load
from pathlib import Path
from csnake import CodeWriter
import csnake

from header import Header
from header import Header, C_STDLIB_HEADER_IMPORT_MAP
from structs import Struct, Union, Enum, RTTIClass
from typedef import Typedef
from functions import FuncPtr, DefinedFunctionPrototype
from functions import FuncPtr, DefinedFunctionPrototype, CSnakeFuncPtr
from vftable import Vftable
from utils import partition, extract_type_name
from vanilla_filepaths import map_projects_to_object_files, get_object_file_base_names, roomate_classes
Expand Down Expand Up @@ -64,6 +65,66 @@ def is_globals_helper_struct(name: str) -> bool:
'FUNCTIONPROTO': FuncPtr,
}

# TODO: Do every type of variable
def arg_clang_wrapping_declaration_to_csnake(wrapping_declaration):
assert wrapping_declaration.kind.name == "FUNCTION_DECL"

param_declaration = next(wrapping_declaration.get_arguments())
assert param_declaration.kind.name == "PARM_DECL"

type_ = param_declaration.type

is_pointer = type_.kind.name == "POINTER"

if is_pointer:
type_ = type_.get_pointee()
if type_.kind.name == "FUNCTIONPROTO":
call_type = None
for attr in ["fastcall", "cdecl", "thiscall", "stdcall"]:
if type_.spelling.endswith(f"__attribute__(({attr}))"):
call_type = f"__{attr}"
break
return CSnakeFuncPtr(type_.get_result().spelling, [(c.displayname or f"param_{i + 1}", c.type.spelling) for i, c in enumerate(param_declaration.get_children())], call_type)

return None


def arg_to_csnake(type_decl):
"""This function is slow because it launched the compiler for each param. It can easily go from less than a second to more than 20 seconds"""
source = f"void __arg_to_csnake_wrapping_declaration({type_decl});"
translation_unit = cindex.TranslationUnit.from_source('tmp.c', args=["-m32"], unsaved_files=[('tmp.c', source)])
if [d for d in translation_unit.diagnostics if d.severity >= cindex.Diagnostic.Error]:
return type_decl
result = arg_clang_wrapping_declaration_to_csnake(next(c for c in translation_unit.cursor.get_children() if c.spelling == "__arg_to_csnake_wrapping_declaration"))
return result or type_decl


def batched_arg_to_csnake(type_decls):
# Slow path
# for p in type_decls:
# p.args = list(map(arg_to_csnake, p.args))

# Fast path
arg_list_list = []
for p in type_decls:
arg_list_list.append(p.args)

source_list = [f"#include <{i}>" for i in set(C_STDLIB_HEADER_IMPORT_MAP.values())]
for i, p in enumerate(arg_list_list):
for j, a in enumerate(p):
source_list.append(f"void __arg_to_csnake_wrapping_declaration__{i}__{j}__({a});")

source = "\n".join(source_list)
translation_unit = cindex.TranslationUnit.from_source('tmp.c', args=["-m32"], unsaved_files=[('tmp.c', source)])

assert len([d for d in translation_unit.diagnostics if d.severity >= cindex.Diagnostic.Error]) == 0

for wrapping_declaration in (c for c in translation_unit.cursor.get_children() if c.spelling.startswith("__arg_to_csnake_wrapping_declaration")):
i, j = wrapping_declaration.spelling.removeprefix("__arg_to_csnake_wrapping_declaration__").removesuffix("__").split("__")
type_ = arg_clang_wrapping_declaration_to_csnake(wrapping_declaration)
if type_:
type_decls[int(i)].args[int(j)] = type_


# TODO: For each global and their types, create inspector entires: webserver or imgui inspector window
if __name__ == "__main__":
Expand Down Expand Up @@ -147,6 +208,8 @@ def is_ignore_struct(data_type) -> bool:
lambda x: type(x) is Struct,
], primitives)

batched_arg_to_csnake(vftable_function_prototypes)

vftable_function_proto_map = {i.name: i for i in vftable_function_prototypes}

lh_linked_list_pointer_structs = {"struct " + i.name.removeprefix("LHLinkedList__p_").removeprefix("LHLinkedNode__p_") for i in lh_linked_pointer_lists}
Expand Down Expand Up @@ -267,7 +330,7 @@ def get_path(name):
wrote_headers = 0
wrote_bytes = 0
for h in headers:
cw = CodeWriter(indent=2)
cw = csnake.CodeWriter(indent=2)
h.to_code(cw)
path = output_path / h.path
path.parent.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 4 additions & 2 deletions scripts/headers/bw1_decomp_gen/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def strip_arrays_and_modifiers(c_type):


def strip_pointers_arrays_and_modifiers(c_type):
if isinstance(c_type, csnake.FuncPtr):
return c_type
c_type = re.sub(r'\*', '', c_type)
c_type = strip_arrays_and_modifiers(c_type)
return c_type
Expand Down Expand Up @@ -161,7 +163,7 @@ def get_includes(self) -> list[str]:

def get_direct_dependencies(self) -> set[str]:
result = self.get_types()
pointers = set(filter(lambda x: '*' in x and strip_pointers_arrays_and_modifiers(x) not in C_STDLIB_TYPEDEFS, result))
pointers = set(filter(lambda x: isinstance(x, csnake.FuncPtr) or ('*' in x and strip_pointers_arrays_and_modifiers(x) not in C_STDLIB_TYPEDEFS), result))
result.difference_update(pointers)
lh_lists = {i for i in result if i.startswith("struct LHListHead__") or i.startswith("struct LHLinkedList__")}
lh_lists_underlying_type = {"struct " + i.removeprefix("struct ").removeprefix("LHListHead__").removeprefix("LHLinkedList__p_").removeprefix("LHLinkedList__") for i in lh_lists}
Expand Down Expand Up @@ -189,7 +191,7 @@ def get_forward_declare_types(self) -> set[str]:
for s in self.structs:
defined_types_so_far.add(f"struct {strip_pointers_arrays_and_modifiers(s.name)}")
struct_types = {strip_pointers_arrays_and_modifiers(r) for r in s.get_types()}
struct_types = {r for r in struct_types if r.startswith("struct ") or r.startswith("union ") or r.startswith("enum ")}
struct_types = {r for r in struct_types if type(r) is str and (r.startswith("struct ") or r.startswith("union ") or r.startswith("enum "))}
struct_types.difference_update(defined_types_so_far)
result.update(struct_types)

Expand Down
4 changes: 2 additions & 2 deletions scripts/headers/bw1_decomp_gen/vftable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional
import csnake

from dataclasses import dataclass
Expand All @@ -13,7 +13,7 @@ class Vftable(Struct):
class Member:
name: str
type: Union[FuncPtr, str]
comment: str
comment: Optional[str]

def to_csnake(self) -> csnake.Variable:
if type(self.type) is FuncPtr:
Expand Down
19 changes: 19 additions & 0 deletions scripts/headers/tests/function_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest
import sys
import os
import csnake
sys.path.append(os.path.dirname(__file__) + "/../bw1_decomp_gen")

from functions import FuncPtr, CSnakeFuncPtr


class TestCSnakeFuncPtr(unittest.TestCase):
def test_func_ptr_func_ptr_arg_to_code(self):
f = CSnakeFuncPtr("void", [("foo", csnake.FuncPtr("void", [("param_1", "int"), ("param_2", "float"), ("param_3", "int")]))], "__fastcall")
self.assertEqual(f.get_declaration("Foo"), "void (__fastcall* Foo)(void (*foo)(int param_1, float param_2, int param_3))")

class TestFuncPtr(unittest.TestCase):
def test_func_ptr_func_ptr_arg_to_code(self):
f = FuncPtr("TestStructVftable__Foo", "__fastcall", "void", [csnake.FuncPtr("void", [("param_1", "int"), ("param_2", "float"), ("param_3", "int")])], ["foo"])
csnake_obj = f.to_csnake()
self.assertEqual(csnake_obj.get_declaration("Foo"), "void (__fastcall* Foo)(void (*foo)(int param_1, float param_2, int param_3))")
41 changes: 41 additions & 0 deletions scripts/headers/tests/header_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import os
import csnake
sys.path.append(os.path.dirname(__file__) + "/../bw1_decomp_gen")

import unittest
Expand Down Expand Up @@ -621,5 +622,45 @@ def test_structs_with_functions(self):
// win1.41 00405070 mac 10102030 TestChildStruct::Qux(int)
char* __fastcall Qux__15TestChildStructFi(struct TestChildStruct* this, const void* edx, int test);
#endif /* BW1_DECOMP_TEST_HEADER_INCLUDED_H */
""")

def test_class_with_callback_in_vftable(self):
function_proto_map = {
"TestStructVftable__Foo": FuncPtr("TestStructVftable__Foo", "__thiscall", "void", ["struct TestStruct *", csnake.FuncPtr("void", [("param_1", "int"), ("param_2", "float"), ("param_3", "int")])], ["this", "foo"]),
}
virtual_table_function_names = (
"Foo",
)
structs: list[Struct] = [
Vftable(Struct("TestStructVftable", 4, [Struct.Member("Foo", "TestStructVftable__Foo*", 0x0)]), function_proto_map),
RTTIClass(Struct("TestStruct", 4, [Struct.Member("vftable", "struct TestStructVftable*", 0x0)]), {}, virtual_table_function_names, {}, {}),
]
header = Header(self.path, includes=[], structs=structs)
header.build_include_list({})
header.to_code(self.cw)

self.assertEqual(self.cw.code,
"""\
#ifndef BW1_DECOMP_TEST_HEADER_INCLUDED_H
#define BW1_DECOMP_TEST_HEADER_INCLUDED_H
#include <assert.h> /* For static_assert */
// Forward Declares
struct TestStruct;
struct TestStructVftable
{
void (__fastcall* Foo)(struct TestStruct* this, const void* edx, void (*foo)(int param_1, float param_2, int param_3));
};
static_assert(sizeof(struct TestStructVftable) == 0x4, "Data type is of wrong size");
struct TestStruct
{
struct TestStructVftable* vftable;
};
static_assert(sizeof(struct TestStruct) == 0x4, "Data type is of wrong size");
#endif /* BW1_DECOMP_TEST_HEADER_INCLUDED_H */
""")
27 changes: 24 additions & 3 deletions scripts/headers/tests/struct_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from vftable import Vftable
from structs import Struct
from functions import FuncPtr
import unittest
import sys
import os
import csnake
sys.path.append(os.path.dirname(__file__) + "/../bw1_decomp_gen")

from vftable import Vftable
from structs import Struct
from functions import FuncPtr


class TestMember(unittest.TestCase):
def test_func_ptr_func_ptr_arg_to_code(self):
m = Vftable.Member("Foo", FuncPtr("TestStructVftable__Foo", "__thiscall", "void", ["struct TestStruct *", csnake.FuncPtr("void", [("param_1", "int"), ("param_2", "float"), ("param_3", "int")])], ["this", "foo"]), None)
csnake_obj = m.to_csnake()
self.assertEqual(csnake_obj.declaration, "void (__fastcall* Foo)(struct TestStruct* this, const void* edx, void (*foo)(int param_1, float param_2, int param_3));")


class TestStruct(unittest.TestCase):
def test_get_type_int(self):
Expand All @@ -21,3 +30,15 @@ def test_get_type_vftable(self):
"Foo", "TestVftable__Foo*", 0x0), Struct.Member("Bar", "TestVftable__Bar*", 0x4)]), function_proto_map)
self.assertSetEqual(s.get_types(), {
"struct TestVftable", "struct Test*", "size_t", "char*", "float"})

def test_func_ptr_func_ptr_arg_to_code(self):
function_proto_map = {
"TestStructVftable__Foo": FuncPtr("TestStructVftable__Foo", "__thiscall", "void", ["struct TestStruct *", csnake.FuncPtr("void", [("param_1", "int"), ("param_2", "float"), ("param_3", "int")])], ["this", "foo"]),
}
s = Vftable(Struct("TestStructVftable", 4, [Struct.Member("Foo", "TestStructVftable__Foo*", 0x0)]), function_proto_map)
csnake_obj = s.to_csnake()
self.assertEqual(csnake_obj.declaration.code, """\
struct TestStructVftable
{
void (__fastcall* Foo)(struct TestStruct* this, const void* edx, void (*foo)(int param_1, float param_2, int param_3));
};""")

0 comments on commit 0f816eb

Please sign in to comment.