diff --git a/docs/src/examples/compton.jl b/docs/src/examples/compton.jl
index ed1caac..5e25f0b 100644
--- a/docs/src/examples/compton.jl
+++ b/docs/src/examples/compton.jl
@@ -33,7 +33,7 @@ length(feynman_diagrams(proc))
# Next, we can generate the DAG representing the computation for our scattering process'
# squared matrix element. This uses [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl).
-dag = generate_DAG(proc)
+dag = graph(proc)
# In this graph output you can see the number of nodes necessary to compute.
# Note that for larger processes, the number of total nodes can be *lower* than
diff --git a/docs/src/examples/trident.jl b/docs/src/examples/trident.jl
index f68d5c8..ab08f01 100644
--- a/docs/src/examples/trident.jl
+++ b/docs/src/examples/trident.jl
@@ -32,7 +32,7 @@ length(feynman_diagrams(proc))
# Next, we can generate the DAG representing the computation for our scattering process'
# squared matrix element. This uses `ComputableDAGs.jl`.
-dag = generate_DAG(proc)
+dag = graph(proc)
# To continue, we will need [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl). Since `ComputableDAGs.jl` uses
# `RuntimeGeneratedFunction`s as the return type of [`ComputableDAGs.get_compute_function`](@extref), we need
diff --git a/docs/src/lib/public.md b/docs/src/lib/public.md
index 9cf9a05..413148b 100644
--- a/docs/src/lib/public.md
+++ b/docs/src/lib/public.md
@@ -15,7 +15,7 @@ VirtualParticle
```@docs
external_particles
feynman_diagrams
-generate_DAG
+graph
process
virtual_particles
```
diff --git a/src/QEDFeynmanDiagrams.jl b/src/QEDFeynmanDiagrams.jl
index fbd5302..2541fef 100644
--- a/src/QEDFeynmanDiagrams.jl
+++ b/src/QEDFeynmanDiagrams.jl
@@ -12,7 +12,7 @@ using DataStructures
export FeynmanDiagram, VirtualParticle
export feynman_diagrams
-export external_particles, virtual_particles, process, generate_DAG
+export external_particles, virtual_particles, process, graph
include("flat_matrix.jl")
diff --git a/src/computable_dags/generation.jl b/src/computable_dags/generation.jl
index 9806354..9be7b37 100644
--- a/src/computable_dags/generation.jl
+++ b/src/computable_dags/generation.jl
@@ -364,11 +364,11 @@ function _is_index_valid_combination(proc::AbstractProcessDefinition, index::Tup
end
"""
- generate_DAG(proc::AbstractProcessDefinition)
+ ComputableDAGs.graph(proc::AbstractProcessDefinition)
Generate and return a [`ComputableDAGs.DAG`](@extref), representing the computation for the squared matrix element of this scattering process, summed over spin and polarization combinations allowed by the process.
"""
-function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
+function ComputableDAGs.graph(proc::PROC) where {PROC<:AbstractProcessDefinition}
I = number_incoming_particles(proc)
O = number_outgoing_particles(proc)
SPECIFIC_VP = VirtualParticle{PROC,NTuple{I,Bool},NTuple{O,Bool}}
@@ -378,7 +378,7 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
sort!(pairs)
triples = sort(total_particle_triples(particles)) # triples to generate the triple tasks
- graph = DAG()
+ g = DAG()
# -- Base State Tasks --
propagated_outputs = Dict{VirtualParticle,Vector{Node}}()
@@ -395,19 +395,19 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
# names are "bs_
___"
data_node_name = "bs_$(_dir_str(dir))_$(_species_str(species))_$(_spin_pol_str(spin_pol))_$(index)"
- data_in = insert_node!(graph, DataTask(0), data_node_name)
+ data_in = insert_node!(g, DataTask(0), data_node_name)
# generate initial base_state tasks
- compute_base_state = insert_node!(graph, ComputeTask_BaseState())
+ compute_base_state = insert_node!(g, ComputeTask_BaseState())
data_out = insert_node!(
- graph,
+ g,
DataTask(0),
"$(_total_index(proc, dir, species, index))_$(_spin_pol_str(spin_pol))",
)
- insert_edge!(graph, data_in, compute_base_state)
- insert_edge!(graph, compute_base_state, data_out)
+ insert_edge!(g, data_in, compute_base_state)
+ insert_edge!(g, compute_base_state, data_out)
if !haskey(propagated_outputs, p)
propagated_outputs[p] = Vector{Node}()
@@ -426,12 +426,12 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
data_node_name = "pr_$vp_index"
- data_in = insert_node!(graph, DataTask(0), data_node_name)
- compute_vp_propagator = insert_node!(graph, ComputeTask_Propagator())
- data_out = insert_node!(graph, DataTask(0))
+ data_in = insert_node!(g, DataTask(0), data_node_name)
+ compute_vp_propagator = insert_node!(g, ComputeTask_Propagator())
+ data_out = insert_node!(g, DataTask(0))
- insert_edge!(graph, data_in, compute_vp_propagator)
- insert_edge!(graph, compute_vp_propagator, data_out)
+ insert_edge!(g, data_in, compute_vp_propagator)
+ insert_edge!(g, compute_vp_propagator, data_out)
propagator_task_outputs[vp] = data_out
end
@@ -464,22 +464,16 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
# make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations
- compute_pair = insert_node!(graph, ComputeTask_Pair())
- pair_data_out = insert_node!(graph, DataTask(0))
+ compute_pair = insert_node!(g, ComputeTask_Pair())
+ pair_data_out = insert_node!(g, DataTask(0))
insert_edge!(
- graph,
- in_nodes[1],
- compute_pair,
- _edge_index_from_vp(input_particles[1]),
+ g, in_nodes[1], compute_pair, _edge_index_from_vp(input_particles[1])
)
insert_edge!(
- graph,
- in_nodes[2],
- compute_pair,
- _edge_index_from_vp(input_particles[2]),
+ g, in_nodes[2], compute_pair, _edge_index_from_vp(input_particles[2])
)
- insert_edge!(graph, compute_pair, pair_data_out)
+ insert_edge!(g, compute_pair, pair_data_out)
if !haskey(pair_output_nodes_by_spin_pol, index)
pair_output_nodes_by_spin_pol[index] = Vector()
@@ -492,25 +486,23 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
for (index, nodes_to_sum) in pair_output_nodes_by_spin_pol
compute_pairs_sum = insert_node!(
- graph, ComputeTask_CollectPairs(length(nodes_to_sum))
+ g, ComputeTask_CollectPairs(length(nodes_to_sum))
)
- data_pairs_sum = insert_node!(graph, DataTask(0))
- compute_propagated = insert_node!(graph, ComputeTask_PropagatePairs())
+ data_pairs_sum = insert_node!(g, DataTask(0))
+ compute_propagated = insert_node!(g, ComputeTask_PropagatePairs())
# give this out node the correct name
- data_out_propagated = insert_node!(
- graph, DataTask(0), _make_node_name([index...])
- )
+ data_out_propagated = insert_node!(g, DataTask(0), _make_node_name([index...]))
for node in nodes_to_sum
- insert_edge!(graph, node, compute_pairs_sum)
+ insert_edge!(g, node, compute_pairs_sum)
end
- insert_edge!(graph, compute_pairs_sum, data_pairs_sum)
+ insert_edge!(g, compute_pairs_sum, data_pairs_sum)
- insert_edge!(graph, propagator_node, compute_propagated, 1)
- insert_edge!(graph, data_pairs_sum, compute_propagated, 2)
+ insert_edge!(g, propagator_node, compute_propagated, 1)
+ insert_edge!(g, data_pairs_sum, compute_propagated, 2)
- insert_edge!(graph, compute_propagated, data_out_propagated)
+ insert_edge!(g, compute_propagated, data_out_propagated)
push!(propagated_outputs[product_particle], data_out_propagated)
end
@@ -530,14 +522,14 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
continue
end
- compute_triples = insert_node!(graph, ComputeTask_Triple())
- data_triples = insert_node!(graph, DataTask(0))
+ compute_triples = insert_node!(g, ComputeTask_Triple())
+ data_triples = insert_node!(g, DataTask(0))
- insert_edge!(graph, a, compute_triples, 1) # first argument photons
- insert_edge!(graph, b, compute_triples, 2) # second argument electrons
- insert_edge!(graph, c, compute_triples, 3) # third argument positrons
+ insert_edge!(g, a, compute_triples, 1) # first argument photons
+ insert_edge!(g, b, compute_triples, 2) # second argument electrons
+ insert_edge!(g, c, compute_triples, 3) # third argument positrons
- insert_edge!(graph, compute_triples, data_triples)
+ insert_edge!(g, compute_triples, data_triples)
if !haskey(triples_results, index)
triples_results[index] = Vector{DataTaskNode}()
@@ -550,27 +542,27 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition}
collected_triples = Vector{DataTaskNode}()
for (index, results) in triples_results
compute_collect_triples = insert_node!(
- graph, ComputeTask_CollectTriples(length(results))
+ g, ComputeTask_CollectTriples(length(results))
)
- data_collect_triples = insert_node!(graph, DataTask(0))
+ data_collect_triples = insert_node!(g, DataTask(0))
for triple in results
- insert_edge!(graph, triple, compute_collect_triples)
+ insert_edge!(g, triple, compute_collect_triples)
end
- insert_edge!(graph, compute_collect_triples, data_collect_triples)
+ insert_edge!(g, compute_collect_triples, data_collect_triples)
push!(collected_triples, data_collect_triples)
end
# Finally, abs2 sum over spin/pol configurations
compute_total_result = insert_node!(
- graph, ComputeTask_SpinPolCumulation(length(collected_triples))
+ g, ComputeTask_SpinPolCumulation(length(collected_triples))
)
for finished_triple in collected_triples
- insert_edge!(graph, finished_triple, compute_total_result)
+ insert_edge!(g, finished_triple, compute_total_result)
end
- final_data_out = insert_node!(graph, DataTask(0))
- insert_edge!(graph, compute_total_result, final_data_out)
- return graph
+ final_data_out = insert_node!(g, DataTask(0))
+ insert_edge!(g, compute_total_result, final_data_out)
+ return g
end
diff --git a/test/synced_spin_pol.jl b/test/synced_spin_pol.jl
index e9bb011..3542a24 100644
--- a/test/synced_spin_pol.jl
+++ b/test/synced_spin_pol.jl
@@ -34,9 +34,9 @@ RNG = MersenneTwister(0)
(AllSpin(), AllPol()),
)
- g_synced = generate_DAG(proc_synced)
- g_polx = generate_DAG(proc_polx)
- g_poly = generate_DAG(proc_poly)
+ g_synced = graph(proc_synced)
+ g_polx = graph(proc_polx)
+ g_poly = graph(proc_poly)
@test length(g_polx.nodes) == length(g_poly.nodes)
@test length(g_polx.nodes) < length(g_synced.nodes) < 2 * length(g_polx.nodes)
@@ -99,8 +99,8 @@ GC.gc()
)
end
- g_synced = generate_DAG(proc_synced)
- graphs = generate_DAG.(procs)
+ g_synced = graph(proc_synced)
+ graphs = graph.(procs)
for g in graphs[2:end]
@test length(graphs[1].nodes) == length(g.nodes)