Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework FunctionCalls to allow multiple returned objects #44

Merged
merged 12 commits into from
Dec 4, 2024
11 changes: 3 additions & 8 deletions ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,21 @@ function ComputableDAGs.kernel(
machine = cpu_st()
tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)
# TODO: use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
11 changes: 3 additions & 8 deletions ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@ function ComputableDAGs.kernel(
machine = cpu_st()
tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)

# TODO use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
),
)
# Build the kernel source as a string and parse it. Each work-item derives its global
# index from the workgroup index, the workgroup size and its own index inside the
# workgroup (`workitemIdx`); work-items whose index exceeds `n` return immediately.
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
8 changes: 2 additions & 6 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ export reset_graph!
export get_operations

# code generation related
export execute
export get_compute_function
export gen_tape, execute_tape
export unpack_identity

# estimator
export cost_type, graph_cost, operation_effect
Expand All @@ -50,8 +47,7 @@ export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!

# models
export AbstractModel, AbstractProblemInstance
export problem_instance, input_type, graph, input_expr
export graph, input_type, input_expr

# machine info
export Machine
Expand All @@ -69,6 +65,7 @@ include("properties/type.jl")
include("operation/type.jl")
include("graph/type.jl")
include("scheduler/type.jl")
include("code_gen/type.jl")

include("trie.jl")
include("utils.jl")
Expand Down Expand Up @@ -127,7 +124,6 @@ include("devices/ext.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")
Expand Down
44 changes: 11 additions & 33 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,50 +17,28 @@ in your top level.

## Keyword Arguments

`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the
compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
**Note** that the actually used closure size might be different than the one passed here, since the function automatically chooses a size that
is close to an n-th root of the total number of lines of code, based on the given size.
"""
function get_compute_function(
    graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
)
    tape = gen_tape(graph, instance, machine, context_module)

    # Main body of the generated function, plus the input-assignment expressions
    # taken from the tape.
    code = gen_function_body(tape, context_module; closures_size=closures_size)
    assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)

    # The generated function's name carries a suffix derived from a fresh UUID.
    function_id = to_var_name(UUIDs.uuid1(rng[1]))
    res_sym = tape.output_symbol
    expr = Expr(
        :function, # function definition
        Expr(
            :call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance))
        ), # function name and parameters
        Expr(:block, assign_inputs, code, Expr(:return, res_sym)), # function body
    )

    return RuntimeGeneratedFunction(@__MODULE__, context_module, expr)
end

"""
execute(
graph::DAG,
instance,
machine::Machine,
input,
context_module::Module
)

Execute the code of the given `graph` on the given input values.

This is essentially shorthand for
```julia
tape = gen_tape(graph, instance, machine, context_module)
return execute_tape(tape, input)
```
"""
function execute(graph::DAG, instance, machine::Machine, input, context_module::Module)
    # Generate the tape for this graph/instance/machine and run it on `input` directly.
    return execute_tape(gen_tape(graph, instance, machine, context_module), input)
end
Loading
Loading