-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompiled.py
119 lines (92 loc) · 3.22 KB
/
compiled.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
"""
from bindings import *
import weldtypes
import ctypes
import time
import numpy as np
class WeldEncoder(object):
    """
    An abstract class that must be overridden by libraries. This class
    is used to marshal objects from Python types to Weld types.

    Subclasses implement both `encode` and `py_to_weld_type`; the base
    implementations only raise NotImplementedError.
    """

    def encode(self, obj):
        """
        Encodes the Python object `obj` into a value that can be stored in
        a ctypes Structure field.

        The field's ctypes type is `self.py_to_weld_type(obj).ctype_class`,
        so the returned value must be assignable to a field of that type.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # Bug fix: the original signature `encode(obj)` omitted `self`, so
        # calling it on an instance raised TypeError instead of this error.
        raise NotImplementedError

    def py_to_weld_type(self, obj):
        """
        Returns the Weld type describing `obj`.

        The returned object must expose a `ctype_class` attribute: the
        ctypes type used to hold the encoded value.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
class WeldDecoder(object):
    """
    An abstract class that must be overridden by libraries. This class
    is used to marshal objects from Weld types to Python types.

    Subclasses implement `decode`; the base implementation only raises
    NotImplementedError.
    """

    def decode(self, obj, restype):
        """
        Decodes obj, assuming object is of type `restype`. obj's Python
        type is ctypes.POINTER(restype.ctype_class).

        Parameters
        ----------
        obj : ctypes pointer
            Pointer to the raw result produced by running a Weld program.
        restype : object
            The Weld type of the result; exposes a `ctype_class` attribute.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # Bug fix: the original signature `decode(obj, restype)` omitted
        # `self`, so calling it on an instance raised TypeError instead.
        raise NotImplementedError
def args_factory(arg_names, arg_types):
    """Creates a ctypes Structure subclass describing a packed argument list.

    Parameters
    ----------
    arg_names : iterable of str
        Field names, in order.
    arg_types : iterable of ctypes types
        Field types, paired positionally with `arg_names`. As with `zip`,
        surplus entries in the longer iterable are ignored.

    Returns
    -------
    type
        A new `ctypes.Structure` subclass named ``Args`` whose `_fields_`
        list holds the (name, type) pairs.
    """
    field_specs = [(field_name, field_ctype)
                   for field_name, field_ctype in zip(arg_names, arg_types)]

    class Args(ctypes.Structure):
        _fields_ = field_specs

    return Args
def compile(program, arg_types, restype, decoder, verbose=False):
    """Compiles a Weld program and returns a function for calling it.

    Note: this name shadows the builtin `compile` within this module; it is
    kept for backward compatibility with existing callers.

    Parameters
    ----------
    program : str
        The source text of a Weld program.
    arg_types : sequence
        One entry per argument of the returned function. Each entry is
        either a `WeldEncoder` instance (used to derive the argument's C type
        and to encode the value), or a Python type whose builtin ctypes
        encoder is looked up with `weldtypes.encoder`.
    restype : object
        The Weld type of the result; its `ctype_class` attribute is used to
        interpret the raw result pointer.
    decoder : WeldDecoder
        Decoder used to convert the raw result into a Python value.
    verbose : bool, optional
        When true, print the compile time.

    Returns
    -------
    callable
        `func(*args)`, which packs `args` into a ctypes Structure, runs the
        compiled module, and returns `decoder.decode(...)` of the result.

    Raises
    ------
    ValueError
        If the program fails to compile, or (from the returned function) if
        running it fails.
    TypeError
        From the returned function, on an argument-count mismatch or when a
        primitive argument is not an instance of its declared type.
    """
    # Materialize once so the sequence can be measured and re-iterated on
    # every call (an iterator would otherwise be exhausted after one call).
    arg_types = tuple(arg_types)

    start = time.time()
    conf = WeldConf()
    err = WeldError()
    module = WeldModule(program, conf, err)
    if err.code() != 0:
        # Bug fix: the original formatted an undefined name `function`,
        # which turned every compile failure into a NameError.
        raise ValueError("Could not compile function {}: {}".format(
            program, err.message()))
    end = time.time()
    if verbose:
        print("Weld compile time:", end - start)

    def func(*args):
        # zip() would silently truncate; a short argument list would build a
        # struct with fewer fields than the compiled program reads.
        if len(args) != len(arg_types):
            raise TypeError(
                "Weld function expects {} argument(s), got {}".format(
                    len(arg_types), len(args)))

        names = []        # Struct field names: "_0", "_1", ...
        arg_c_types = []  # ctypes type of each field.
        encoded = []      # Encoded value stored in each field.
        for i, (arg, arg_type) in enumerate(zip(args, arg_types)):
            names.append("_{}".format(i))
            if isinstance(arg_type, WeldEncoder):
                arg_c_types.append(arg_type.py_to_weld_type(arg).ctype_class)
                encoded.append(arg_type.encode(arg))
            else:
                # Primitive type with a builtin encoder. Validate explicitly:
                # an `assert` would be stripped under `python -O`.
                if not isinstance(arg, arg_type):
                    raise TypeError(
                        "argument {} must be of type {}, got {}".format(
                            i, arg_type, type(arg)))
                ctype = weldtypes.encoder(arg_type)
                arg_c_types.append(ctype)
                encoded.append(ctype(arg))

        Args = args_factory(names, arg_c_types)
        raw_args = Args()
        for name, value in zip(names, encoded):
            setattr(raw_args, name, value)

        # `raw_args` stays referenced until run() returns, so the pointer
        # handed to Weld remains valid for the duration of the call.
        raw_args_pointer = ctypes.cast(ctypes.byref(raw_args), ctypes.c_void_p)
        weld_input = WeldValue(raw_args_pointer)
        run_conf = WeldConf()
        run_err = WeldError()
        result = module.run(run_conf, weld_input, run_err)
        if run_err.code() != 0:
            # Bug fix: same undefined `function` name as the compile path.
            raise ValueError(("Error while running function,\n{}\n\n"
                              "Error message: {}").format(
                                  program, run_err.message()))

        # NOTE(review): the WeldValue `result` is not explicitly released
        # here; confirm ownership/lifetime rules with the bindings module.
        pointer_type = ctypes.POINTER(restype.ctype_class)
        data = ctypes.cast(result.data(), pointer_type)
        return decoder.decode(data, restype)

    return func