From 0ea9ca9c4372dcd5e91387b6884fd7fc975f6873 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Fri, 1 Nov 2024 14:13:44 +0100 Subject: [PATCH] Overload correct interface function `ComputableDAGs.graph` (#39) --- docs/src/examples/compton.jl | 2 +- docs/src/examples/trident.jl | 2 +- docs/src/lib/public.md | 2 +- src/QEDFeynmanDiagrams.jl | 2 +- src/computable_dags/generation.jl | 92 ++++++++++++++----------------- test/synced_spin_pol.jl | 10 ++-- 6 files changed, 51 insertions(+), 59 deletions(-) diff --git a/docs/src/examples/compton.jl b/docs/src/examples/compton.jl index ed1caac..5e25f0b 100644 --- a/docs/src/examples/compton.jl +++ b/docs/src/examples/compton.jl @@ -33,7 +33,7 @@ length(feynman_diagrams(proc)) # Next, we can generate the DAG representing the computation for our scattering process' # squared matrix element. This uses [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl). -dag = generate_DAG(proc) +dag = graph(proc) # In this graph output you can see the number of nodes necessary to compute. # Note that for larger processes, the number of total nodes can be *lower* than diff --git a/docs/src/examples/trident.jl b/docs/src/examples/trident.jl index f68d5c8..ab08f01 100644 --- a/docs/src/examples/trident.jl +++ b/docs/src/examples/trident.jl @@ -32,7 +32,7 @@ length(feynman_diagrams(proc)) # Next, we can generate the DAG representing the computation for our scattering process' # squared matrix element. This uses `ComputableDAGs.jl`. -dag = generate_DAG(proc) +dag = graph(proc) # To continue, we will need [`ComputableDAGs.jl`](https://github.com/ComputableDAGs/ComputableDAGs.jl). Since `ComputableDAGs.jl` uses # `RuntimeGeneratedFunction`s as the return type of [`ComputableDAGs.get_compute_function`](@extref), we need diff --git a/docs/src/lib/public.md b/docs/src/lib/public.md index 9cf9a05..413148b 100644 --- a/docs/src/lib/public.md +++ b/docs/src/lib/public.md @@ -15,7 +15,7 @@ VirtualParticle ```@docs external_particles feynman_diagrams -generate_DAG +graph process virtual_particles ``` diff --git a/src/QEDFeynmanDiagrams.jl b/src/QEDFeynmanDiagrams.jl index fbd5302..2541fef 100644 --- a/src/QEDFeynmanDiagrams.jl +++ b/src/QEDFeynmanDiagrams.jl @@ -12,7 +12,7 @@ using DataStructures export FeynmanDiagram, VirtualParticle export feynman_diagrams -export external_particles, virtual_particles, process, generate_DAG +export external_particles, virtual_particles, process, graph include("flat_matrix.jl") diff --git a/src/computable_dags/generation.jl b/src/computable_dags/generation.jl index 9806354..9be7b37 100644 --- a/src/computable_dags/generation.jl +++ b/src/computable_dags/generation.jl @@ -364,11 +364,11 @@ function _is_index_valid_combination(proc::AbstractProcessDefinition, index::Tup end """ - generate_DAG(proc::AbstractProcessDefinition) + ComputableDAGs.graph(proc::AbstractProcessDefinition) Generate and return a [`ComputableDAGs.DAG`](@extref), representing the computation for the squared matrix element of this scattering process, summed over spin and polarization combinations allowed by the process. """ -function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} +function ComputableDAGs.graph(proc::PROC) where {PROC<:AbstractProcessDefinition} I = number_incoming_particles(proc) O = number_outgoing_particles(proc) SPECIFIC_VP = VirtualParticle{PROC,NTuple{I,Bool},NTuple{O,Bool}} @@ -378,7 +378,7 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} sort!(pairs) triples = sort(total_particle_triples(particles)) # triples to generate the triple tasks - graph = DAG() + g = DAG() # -- Base State Tasks -- propagated_outputs = Dict{VirtualParticle,Vector{Node}}() @@ -395,19 +395,19 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} # names are "bs____" data_node_name = "bs_$(_dir_str(dir))_$(_species_str(species))_$(_spin_pol_str(spin_pol))_$(index)" - data_in = insert_node!(graph, DataTask(0), data_node_name) + data_in = insert_node!(g, DataTask(0), data_node_name) # generate initial base_state tasks - compute_base_state = insert_node!(graph, ComputeTask_BaseState()) + compute_base_state = insert_node!(g, ComputeTask_BaseState()) data_out = insert_node!( - graph, + g, DataTask(0), "$(_total_index(proc, dir, species, index))_$(_spin_pol_str(spin_pol))", ) - insert_edge!(graph, data_in, compute_base_state) - insert_edge!(graph, compute_base_state, data_out) + insert_edge!(g, data_in, compute_base_state) + insert_edge!(g, compute_base_state, data_out) if !haskey(propagated_outputs, p) propagated_outputs[p] = Vector{Node}() @@ -426,12 +426,12 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} data_node_name = "pr_$vp_index" - data_in = insert_node!(graph, DataTask(0), data_node_name) - compute_vp_propagator = insert_node!(graph, ComputeTask_Propagator()) - data_out = insert_node!(graph, DataTask(0)) + data_in = insert_node!(g, DataTask(0), data_node_name) + compute_vp_propagator = insert_node!(g, ComputeTask_Propagator()) + data_out = insert_node!(g, DataTask(0)) - insert_edge!(graph, data_in, compute_vp_propagator) - insert_edge!(graph, compute_vp_propagator, data_out) + insert_edge!(g, data_in, compute_vp_propagator) + insert_edge!(g, compute_vp_propagator, data_out) propagator_task_outputs[vp] = data_out end @@ -464,22 +464,16 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} # make the compute pair nodes for every combination of the found input_particle_nodes to get all spin/pol combinations - compute_pair = insert_node!(graph, ComputeTask_Pair()) - pair_data_out = insert_node!(graph, DataTask(0)) + compute_pair = insert_node!(g, ComputeTask_Pair()) + pair_data_out = insert_node!(g, DataTask(0)) insert_edge!( - graph, - in_nodes[1], - compute_pair, - _edge_index_from_vp(input_particles[1]), + g, in_nodes[1], compute_pair, _edge_index_from_vp(input_particles[1]) ) insert_edge!( - graph, - in_nodes[2], - compute_pair, - _edge_index_from_vp(input_particles[2]), + g, in_nodes[2], compute_pair, _edge_index_from_vp(input_particles[2]) ) - insert_edge!(graph, compute_pair, pair_data_out) + insert_edge!(g, compute_pair, pair_data_out) if !haskey(pair_output_nodes_by_spin_pol, index) pair_output_nodes_by_spin_pol[index] = Vector() @@ -492,25 +486,23 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} for (index, nodes_to_sum) in pair_output_nodes_by_spin_pol compute_pairs_sum = insert_node!( - graph, ComputeTask_CollectPairs(length(nodes_to_sum)) + g, ComputeTask_CollectPairs(length(nodes_to_sum)) ) - data_pairs_sum = insert_node!(graph, DataTask(0)) - compute_propagated = insert_node!(graph, ComputeTask_PropagatePairs()) + data_pairs_sum = insert_node!(g, DataTask(0)) + compute_propagated = insert_node!(g, ComputeTask_PropagatePairs()) # give this out node the correct name - data_out_propagated = insert_node!( - graph, DataTask(0), _make_node_name([index...]) - ) + data_out_propagated = insert_node!(g, DataTask(0), _make_node_name([index...])) for node in nodes_to_sum - insert_edge!(graph, node, compute_pairs_sum) + insert_edge!(g, node, compute_pairs_sum) end - insert_edge!(graph, compute_pairs_sum, data_pairs_sum) + insert_edge!(g, compute_pairs_sum, data_pairs_sum) - insert_edge!(graph, propagator_node, compute_propagated, 1) - insert_edge!(graph, data_pairs_sum, compute_propagated, 2) + insert_edge!(g, propagator_node, compute_propagated, 1) + insert_edge!(g, data_pairs_sum, compute_propagated, 2) - insert_edge!(graph, compute_propagated, data_out_propagated) + insert_edge!(g, compute_propagated, data_out_propagated) push!(propagated_outputs[product_particle], data_out_propagated) end @@ -530,14 +522,14 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} continue end - compute_triples = insert_node!(graph, ComputeTask_Triple()) - data_triples = insert_node!(graph, DataTask(0)) + compute_triples = insert_node!(g, ComputeTask_Triple()) + data_triples = insert_node!(g, DataTask(0)) - insert_edge!(graph, a, compute_triples, 1) # first argument photons - insert_edge!(graph, b, compute_triples, 2) # second argument electrons - insert_edge!(graph, c, compute_triples, 3) # third argument positrons + insert_edge!(g, a, compute_triples, 1) # first argument photons + insert_edge!(g, b, compute_triples, 2) # second argument electrons + insert_edge!(g, c, compute_triples, 3) # third argument positrons - insert_edge!(graph, compute_triples, data_triples) + insert_edge!(g, compute_triples, data_triples) if !haskey(triples_results, index) triples_results[index] = Vector{DataTaskNode}() @@ -550,27 +542,27 @@ function generate_DAG(proc::PROC) where {PROC<:AbstractProcessDefinition} collected_triples = Vector{DataTaskNode}() for (index, results) in triples_results compute_collect_triples = insert_node!( - graph, ComputeTask_CollectTriples(length(results)) + g, ComputeTask_CollectTriples(length(results)) ) - data_collect_triples = insert_node!(graph, DataTask(0)) + data_collect_triples = insert_node!(g, DataTask(0)) for triple in results - insert_edge!(graph, triple, compute_collect_triples) + insert_edge!(g, triple, compute_collect_triples) end - insert_edge!(graph, compute_collect_triples, data_collect_triples) + insert_edge!(g, compute_collect_triples, data_collect_triples) push!(collected_triples, data_collect_triples) end # Finally, abs2 sum over spin/pol configurations compute_total_result = insert_node!( - graph, ComputeTask_SpinPolCumulation(length(collected_triples)) + g, ComputeTask_SpinPolCumulation(length(collected_triples)) ) for finished_triple in collected_triples - insert_edge!(graph, finished_triple, compute_total_result) + insert_edge!(g, finished_triple, compute_total_result) end - final_data_out = insert_node!(graph, DataTask(0)) - insert_edge!(graph, compute_total_result, final_data_out) - return graph + final_data_out = insert_node!(g, DataTask(0)) + insert_edge!(g, compute_total_result, final_data_out) + return g end diff --git a/test/synced_spin_pol.jl b/test/synced_spin_pol.jl index e9bb011..3542a24 100644 --- a/test/synced_spin_pol.jl +++ b/test/synced_spin_pol.jl @@ -34,9 +34,9 @@ RNG = MersenneTwister(0) (AllSpin(), AllPol()), ) - g_synced = generate_DAG(proc_synced) - g_polx = generate_DAG(proc_polx) - g_poly = generate_DAG(proc_poly) + g_synced = graph(proc_synced) + g_polx = graph(proc_polx) + g_poly = graph(proc_poly) @test length(g_polx.nodes) == length(g_poly.nodes) @test length(g_polx.nodes) < length(g_synced.nodes) < 2 * length(g_polx.nodes) @@ -99,8 +99,8 @@ GC.gc() ) end - g_synced = generate_DAG(proc_synced) - graphs = generate_DAG.(procs) + g_synced = graph(proc_synced) + graphs = graph.(procs) for g in graphs[2:end] @test length(graphs[1].nodes) == length(g.nodes)