Skip to content

Commit

Permalink
Overload correct interface function ComputableDAGs.graph (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard authored Nov 1, 2024
1 parent 54e343e commit 0ea9ca9
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 59 deletions.
2 changes: 1 addition & 1 deletion docs/src/examples/compton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ length(feynman_diagrams(proc))

# Next, we can generate the DAG representing the computation for our scattering process'
# squared matrix element. This uses [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl).
dag = generate_DAG(proc)
dag = graph(proc)

# In this graph output you can see the number of nodes necessary to compute.
# Note that for larger processes, the number of total nodes can be *lower* than
Expand Down
2 changes: 1 addition & 1 deletion docs/src/examples/trident.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ length(feynman_diagrams(proc))

# Next, we can generate the DAG representing the computation for our scattering process'
# squared matrix element. This uses `ComputableDAGs.jl`.
dag = generate_DAG(proc)
dag = graph(proc)

# To continue, we will need [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl). Since `ComputableDAGs.jl` uses
# `RuntimeGeneratedFunction`s as the return type of [`ComputableDAGs.get_compute_function`](@extref), we need
Expand Down
2 changes: 1 addition & 1 deletion docs/src/lib/public.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ VirtualParticle
```@docs
external_particles
feynman_diagrams
generate_DAG
graph
process
virtual_particles
```
2 changes: 1 addition & 1 deletion src/QEDFeynmanDiagrams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using DataStructures

export FeynmanDiagram, VirtualParticle
export feynman_diagrams
export external_particles, virtual_particles, process, generate_DAG
export external_particles, virtual_particles, process, graph

include("flat_matrix.jl")

Expand Down
92 changes: 42 additions & 50 deletions src/computable_dags/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,11 @@ function _is_index_valid_combination(proc::AbstractProcessDefinition, index::Tup
end

"""
generate_DAG(proc::AbstractProcessDefinition)
ComputableDAGs.graph(proc::AbstractProcessDefinition)
Generate and return a [`ComputableDAGs.DAG`](@extref), representing the computation for the squared matrix element of this scattering process, summed over spin and polarization combinations allowed by the process.
"""
function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
function ComputableDAGs.graph(proc::PROC) where {PROC<:AbstractProcessDefinition}
I = number_incoming_particles(proc)
O = number_outgoing_particles(proc)
SPECIFIC_VP = VirtualParticle{PROC,NTuple{I,Bool},NTuple{O,Bool}}
Expand All @@ -378,7 +378,7 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
sort!(pairs)
triples = sort(total_particle_triples(particles)) # triples to generate the triple tasks

graph = DAG()
g = DAG()

# -- Base State Tasks --
propagated_outputs = Dict{VirtualParticle,Vector{Node}}()
Expand All @@ -395,19 +395,19 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
# names are "bs_<dir>_<species>_<spin/pol>_<index>"
data_node_name = "bs_$(_dir_str(dir))_$(_species_str(species))_$(_spin_pol_str(spin_pol))_$(index)"

data_in = insert_node!(graph, DataTask(0), data_node_name)
data_in = insert_node!(g, DataTask(0), data_node_name)

# generate initial base_state tasks
compute_base_state = insert_node!(graph, ComputeTask_BaseState())
compute_base_state = insert_node!(g, ComputeTask_BaseState())

data_out = insert_node!(
graph,
g,
DataTask(0),
"$(_total_index(proc, dir, species, index))_$(_spin_pol_str(spin_pol))",
)

insert_edge!(graph, data_in, compute_base_state)
insert_edge!(graph, compute_base_state, data_out)
insert_edge!(g, data_in, compute_base_state)
insert_edge!(g, compute_base_state, data_out)

if !haskey(propagated_outputs, p)
propagated_outputs[p] = Vector{Node}()
Expand All @@ -426,12 +426,12 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}

data_node_name = "pr_$vp_index"

data_in = insert_node!(graph, DataTask(0), data_node_name)
compute_vp_propagator = insert_node!(graph, ComputeTask_Propagator())
data_out = insert_node!(graph, DataTask(0))
data_in = insert_node!(g, DataTask(0), data_node_name)
compute_vp_propagator = insert_node!(g, ComputeTask_Propagator())
data_out = insert_node!(g, DataTask(0))

insert_edge!(graph, data_in, compute_vp_propagator)
insert_edge!(graph, compute_vp_propagator, data_out)
insert_edge!(g, data_in, compute_vp_propagator)
insert_edge!(g, compute_vp_propagator, data_out)

propagator_task_outputs[vp] = data_out
end
Expand Down Expand Up @@ -464,22 +464,16 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}

# make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations

compute_pair = insert_node!(graph, ComputeTask_Pair())
pair_data_out = insert_node!(graph, DataTask(0))
compute_pair = insert_node!(g, ComputeTask_Pair())
pair_data_out = insert_node!(g, DataTask(0))

insert_edge!(
graph,
in_nodes[1],
compute_pair,
_edge_index_from_vp(input_particles[1]),
g, in_nodes[1], compute_pair, _edge_index_from_vp(input_particles[1])
)
insert_edge!(
graph,
in_nodes[2],
compute_pair,
_edge_index_from_vp(input_particles[2]),
g, in_nodes[2], compute_pair, _edge_index_from_vp(input_particles[2])
)
insert_edge!(graph, compute_pair, pair_data_out)
insert_edge!(g, compute_pair, pair_data_out)

if !haskey(pair_output_nodes_by_spin_pol, index)
pair_output_nodes_by_spin_pol[index] = Vector()
Expand All @@ -492,25 +486,23 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}

for (index, nodes_to_sum) in pair_output_nodes_by_spin_pol
compute_pairs_sum = insert_node!(
graph, ComputeTask_CollectPairs(length(nodes_to_sum))
g, ComputeTask_CollectPairs(length(nodes_to_sum))
)
data_pairs_sum = insert_node!(graph, DataTask(0))
compute_propagated = insert_node!(graph, ComputeTask_PropagatePairs())
data_pairs_sum = insert_node!(g, DataTask(0))
compute_propagated = insert_node!(g, ComputeTask_PropagatePairs())
# give this out node the correct name
data_out_propagated = insert_node!(
graph, DataTask(0), _make_node_name([index...])
)
data_out_propagated = insert_node!(g, DataTask(0), _make_node_name([index...]))

for node in nodes_to_sum
insert_edge!(graph, node, compute_pairs_sum)
insert_edge!(g, node, compute_pairs_sum)
end

insert_edge!(graph, compute_pairs_sum, data_pairs_sum)
insert_edge!(g, compute_pairs_sum, data_pairs_sum)

insert_edge!(graph, propagator_node, compute_propagated, 1)
insert_edge!(graph, data_pairs_sum, compute_propagated, 2)
insert_edge!(g, propagator_node, compute_propagated, 1)
insert_edge!(g, data_pairs_sum, compute_propagated, 2)

insert_edge!(graph, compute_propagated, data_out_propagated)
insert_edge!(g, compute_propagated, data_out_propagated)

push!(propagated_outputs[product_particle], data_out_propagated)
end
Expand All @@ -530,14 +522,14 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
continue
end

compute_triples = insert_node!(graph, ComputeTask_Triple())
data_triples = insert_node!(graph, DataTask(0))
compute_triples = insert_node!(g, ComputeTask_Triple())
data_triples = insert_node!(g, DataTask(0))

insert_edge!(graph, a, compute_triples, 1) # first argument photons
insert_edge!(graph, b, compute_triples, 2) # second argument electrons
insert_edge!(graph, c, compute_triples, 3) # third argument positrons
insert_edge!(g, a, compute_triples, 1) # first argument photons
insert_edge!(g, b, compute_triples, 2) # second argument electrons
insert_edge!(g, c, compute_triples, 3) # third argument positrons

insert_edge!(graph, compute_triples, data_triples)
insert_edge!(g, compute_triples, data_triples)

if !haskey(triples_results, index)
triples_results[index] = Vector{DataTaskNode}()
Expand All @@ -550,27 +542,27 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
collected_triples = Vector{DataTaskNode}()
for (index, results) in triples_results
compute_collect_triples = insert_node!(
graph, ComputeTask_CollectTriples(length(results))
g, ComputeTask_CollectTriples(length(results))
)
data_collect_triples = insert_node!(graph, DataTask(0))
data_collect_triples = insert_node!(g, DataTask(0))

for triple in results
insert_edge!(graph, triple, compute_collect_triples)
insert_edge!(g, triple, compute_collect_triples)
end
insert_edge!(graph, compute_collect_triples, data_collect_triples)
insert_edge!(g, compute_collect_triples, data_collect_triples)

push!(collected_triples, data_collect_triples)
end

# Finally, abs2 sum over spin/pol configurations
compute_total_result = insert_node!(
graph, ComputeTask_SpinPolCumulation(length(collected_triples))
g, ComputeTask_SpinPolCumulation(length(collected_triples))
)
for finished_triple in collected_triples
insert_edge!(graph, finished_triple, compute_total_result)
insert_edge!(g, finished_triple, compute_total_result)
end

final_data_out = insert_node!(graph, DataTask(0))
insert_edge!(graph, compute_total_result, final_data_out)
return graph
final_data_out = insert_node!(g, DataTask(0))
insert_edge!(g, compute_total_result, final_data_out)
return g
end
10 changes: 5 additions & 5 deletions test/synced_spin_pol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ RNG = MersenneTwister(0)
(AllSpin(), AllPol()),
)

g_synced = generate_DAG(proc_synced)
g_polx = generate_DAG(proc_polx)
g_poly = generate_DAG(proc_poly)
g_synced = graph(proc_synced)
g_polx = graph(proc_polx)
g_poly = graph(proc_poly)

@test length(g_polx.nodes) == length(g_poly.nodes)
@test length(g_polx.nodes) < length(g_synced.nodes) < 2 * length(g_polx.nodes)
Expand Down Expand Up @@ -99,8 +99,8 @@ GC.gc()
)
end

g_synced = generate_DAG(proc_synced)
graphs = generate_DAG.(procs)
g_synced = graph(proc_synced)
graphs = graph.(procs)

for g in graphs[2:end]
@test length(graphs[1].nodes) == length(g.nodes)
Expand Down

0 comments on commit 0ea9ca9

Please sign in to comment.