-
Notifications
You must be signed in to change notification settings - Fork 1
/
wrapper.t
188 lines (161 loc) · 6.22 KB
/
wrapper.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
local function tuple_type_to_list(tpl)
    --[=[
    Collect the element types of a tuple type

    Args:
        tpl: Tuple type

    Returns:
        One-based terra list holding the element types of the tuple, in order

    Examples:
        print(tuple_type_to_list(tuple(int, double))[2])
        -- double
    --]=]
    -- A tuple type keeps its elements in `entries`, a terra list of
    -- two-slot tables: slot 1 is the zero-based index, slot 2 the type.
    -- See the tuple implementation in terralib:
    -- https://github.com/terralang/terra/blob/4d32a10ffe632694aa973c1457f1d3fb9372c737/src/terralib.lua#L1762
    local function entry_type(entry)
        return entry[2]
    end
    return tpl.entries:map(entry_type)
end
local function get_signature_list(func)
    --[=[
    Split the signature of a terra function into argument and return types

    Args:
        func: Terra function

    Returns:
        Two terra lists: the parameter types and the return types.
        When the return type has an `entries` field (as tuple types do),
        it is expanded element-wise via tuple_type_to_list; otherwise the
        return type is wrapped in a one-element list.
        NOTE(review): struct types presumably also carry `entries`;
        confirm that struct return types never reach this path.

    Examples:
        print(get_signature_list(terra(x: int, y: float): double end))
        -- {int32,float} {double}
    --]=]
    assert(terralib.isfunction(func), "Argument must be terra function")
    local signature = func.type
    local result_type = signature.returntype
    if result_type.entries ~= nil then
        return signature.parameters, tuple_type_to_list(result_type)
    end
    local single = terralib.newlist()
    single:insert(result_type)
    return signature.parameters, single
end
local function cast_signature(T, func, TRef)
    --[=[
    Replace a type and pointers to it in a function signature with a new type

    Args:
        T: new type
        func: Terra function whose signature is rewritten
        TRef: old type, to be replaced

    Returns:
        Argument and return types as two separate terra lists. Every
        occurrence of TRef becomes T and every &TRef becomes &T; all other
        types are passed through unchanged.
    --]=]
    local input, output = get_signature_list(func)
    local cast = function(S)
        if S == TRef then
            return T
        elseif S == &TRef then
            return &T
        else
            return S
        end
    end
    return input:map(cast), output:map(cast)
end
local function generate_wrapper(T, c_func, TRef, r_func)
    --[=[
    Generate uniform wrappers for BLAS and LAPACK functions

    Args:
        T: type for which the wrapper is generated
        c_func: Underlying C function, wrapped as a terra function
        TRef: Reference type for function signature
        r_func: Terra function with model function signature
            If not given, c_func is used to extract the function signature

    Returns:
        Wrapper around c_func for type T with same interface as r_func

    Raises:
        An error if c_func does not take either the same number of arguments
        as the reference signature, or exactly one more (the return value
        passed by reference). It also raises an error in the latter case if
        the reference signature has no return type.
    --]=]
    r_func = r_func or c_func
    local terra_arg, terra_ret = cast_signature(T, r_func, TRef)
    local terra_sym = terra_arg:map(symbol)
    local c_arg = get_signature_list(c_func)

    -- Validate the arity before generating any code. Otherwise a C function
    -- with fewer parameters than the reference signature would fail inside
    -- the argument loop below while indexing a nil entry of c_arg.
    -- If the number of arguments match, the return value of the C call is
    -- forwarded via the return statement. If the C function takes one more
    -- argument, we assume that its last parameter receives the return value
    -- by reference.
    local return_by_reference
    if #c_arg == #terra_arg then
        return_by_reference = false
    elseif #c_arg == #terra_arg + 1 then
        if #terra_ret == 0 then
            error("Reference signature has no return value to pass by reference")
        end
        return_by_reference = true
    else
        error(string.format(
            "Unsupported number of arguments: C function takes %d, "
            .. "reference signature takes %d (expected equal or one more)",
            #c_arg, #terra_arg))
    end

    local statement = terralib.newlist()
    local c_sym = terralib.newlist()
    -- BLAS or LAPACK functions can only have the following input types:
    --   Enums: These are passed to terra as uint32
    --   Integers: These are passed to terra as int32
    --   Scalars: For real data, these are mapped to float or double. Complex
    --            data types are passed as opaque pointers in C for BLAS.
    --   Arrays: For complex data types these are passed as opaque pointers
    --           in BLAS
    --   Integer arrays: These are passed to terra as &int32
    --
    -- This means that we need a reference signature to decide if an opaque
    -- pointer refers to a scalar or an array.
    --
    -- To summarize, only scalars and arrays depend on the actual type, i.e.
    -- the prefix of the BLAS/LAPACK routine. Thus, when looping over the
    -- arguments of the function, only those two cases need a conversion;
    -- every other argument is forwarded unchanged.
    for i = 1, #terra_arg do
        local arg_type = terra_arg[i]
        local c_type = c_arg[i]
        if arg_type == T then
            -- Scalar of the wrapped type.
            local scalar = symbol(c_type)
            if c_type:ispointer() then
                -- Complex data types are passed by reference in BLAS, so
                -- hand over the address of the wrapper's argument.
                statement:insert(quote
                    var [scalar] = [ c_type ](&[ terra_sym[i] ])
                end)
            else
                -- Passed by value: reinterpret the argument's storage as
                -- the C parameter type via a pointer cast.
                statement:insert(quote
                    var [scalar] = @[ &c_type ](&[ terra_sym[i] ])
                end)
            end
            c_sym:insert(scalar)
        elseif arg_type == &T then
            -- Array of the wrapped type: cast the pointer to the C type.
            local pointer = symbol(c_type)
            statement:insert(quote
                var [pointer] = [ c_type ]([ terra_sym[i] ])
            end)
            c_sym:insert(pointer)
        else
            c_sym:insert(terra_sym[i])
        end
    end

    local return_statement = terralib.newlist()
    local c_call
    if return_by_reference then
        -- Declare the return value, pass its address (cast to the type of
        -- the C function's last parameter), and return it by value.
        local ref = symbol(terra_ret[1])
        statement:insert(quote var [ref] end)
        local c_ref_type = c_arg[#c_arg]
        local c_ref = symbol(c_ref_type)
        statement:insert(quote var [c_ref] = [ c_ref_type ](&[ref]) end)
        c_sym:insert(c_ref)
        return_statement:insert(quote return [ref] end)
        c_call = quote [c_func]([c_sym]) end
    else
        c_call = quote return [c_func]([c_sym]) end
    end

    local terra wrapper([terra_sym])
        [statement]
        [c_call]
        [return_statement]
    end
    return wrapper
end
-- Module interface: only generate_wrapper is exported; the signature
-- helpers defined as locals in this file are not part of the table.
local M = {}
M.generate_wrapper = generate_wrapper
return M