Skip to content

Commit

Permalink
Merge pull request #146 from GraphStreamingProject/better_verifier
Browse files Browse the repository at this point in the history
Better verifier
  • Loading branch information
etwest authored Mar 21, 2024
2 parents db06e66 + f15c7ec commit d0f3e5b
Show file tree
Hide file tree
Showing 18 changed files with 525 additions and 365 deletions.
4 changes: 1 addition & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ add_library(GraphZeppelinVerifyCC
src/cc_alg_configuration.cpp
src/sketch.cpp
src/util.cpp
test/util/file_graph_verifier.cpp
test/util/mat_graph_verifier.cpp)
test/util/graph_verifier.cpp)
add_dependencies(GraphZeppelinVerifyCC GutterTree StreamingUtilities)
target_link_libraries(GraphZeppelinVerifyCC PUBLIC xxhash GutterTree StreamingUtilities)
target_include_directories(GraphZeppelinVerifyCC PUBLIC include/ include/test/)
Expand All @@ -123,7 +122,6 @@ if (BUILD_EXE)
test/sketch_test.cpp
test/dsu_test.cpp
test/util_test.cpp
test/util/file_graph_verifier.cpp
test/util/graph_verifier_test.cpp)
add_dependencies(tests GraphZeppelinVerifyCC)
target_link_libraries(tests PRIVATE GraphZeppelinVerifyCC)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ int main() {
DriverConfiguration() // configuration
};
driver.process_stream_until(END_OF_STREAM); // Tell the driver to process the entire graph stream
driver.prep_query(); // Ensure that all updates have been processed
driver.prep_query(CONNECTIVITY); // Ensure algorithm is ready for a connectivity query
auto CC = cc_alg.connected_components(); // Extract the connected components
}
```
Expand Down
36 changes: 24 additions & 12 deletions include/cc_sketch_alg.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ struct alignas(64) GlobalMergeData {
}
};

// What type of query is the user going to perform. Used for has_cached_query()
enum QueryCode {
CONNECTIVITY, // connected components and spanning forest of graph
KSPANNINGFORESTS, // k disjoint spanning forests
};

/**
* Algorithm for computing connected components on undirected graph streams
* (no self-edges or multi-edges)
Expand Down Expand Up @@ -85,20 +91,25 @@ class CCSketchAlg {
Sketch **delta_sketches = nullptr;
size_t num_delta_sketches;

CCAlgConfiguration config;
#ifdef VERIFY_SAMPLES_F
std::unique_ptr<GraphVerifier> verifier;
#endif

/**
* Run the first round of Boruvka. We can do things faster here because we know there will
* be no merging we have to do.
*/
bool run_round_zero();

/**
* Update the query array with new samples
* @param query an array of sketch sample results
* @param reps an array containing node indices for the representative of each supernode
* Sample a single supernode represented by a single sketch containing one or more vertices.
* Updates the dsu and spanning forest with query results if edge contains new connectivity info.
* @param skt sketch to sample
* @return [bool] true if the query result indicates we should run an additional round.
*/
bool sample_supernode(Sketch &skt);


/**
* Calculate the instructions for what vertices to merge to form each component
*/
Expand All @@ -117,10 +128,6 @@ class CCSketchAlg {
*/
void boruvka_emulation();

FRIEND_TEST(GraphTestSuite, TestCorrectnessOfReheating);

CCAlgConfiguration config;

// constructor for use when reading from a serialized file
CCSketchAlg(node_id_t num_vertices, size_t seed, std::ifstream &binary_stream,
CCAlgConfiguration config);
Expand Down Expand Up @@ -174,7 +181,13 @@ class CCSketchAlg {
* Return if we have cached an answer to query.
* This allows the driver to avoid flushing the gutters before calling query functions.
*/
bool has_cached_query() { return shared_dsu_valid; }
bool has_cached_query(int query_code) {
QueryCode code = (QueryCode) query_code;
if (code == CONNECTIVITY)
return shared_dsu_valid;
else
return false;
}

/**
* Print the configuration of the connected components graph sketching.
Expand All @@ -201,7 +214,7 @@ class CCSketchAlg {

/**
* Main parallel query algorithm utilizing Boruvka and L_0 sampling.
* @return a vector of the connected components in the graph.
* @return the connected components in the graph.
*/
ConnectedComponents connected_components();

Expand All @@ -217,12 +230,11 @@ class CCSketchAlg {
* Return a spanning forest of the graph utilizing Boruvka and L_0 sampling
* IMPORTANT: The updates to this algorithm MUST NOT be a function of the output of this query
* that is, unless you really know what you're doing.
* @return an adjacency list representation of the spanning forest of the graph
* @return the spanning forest of the graph
*/
SpanningForest calc_spanning_forest();

#ifdef VERIFY_SAMPLES_F
std::unique_ptr<GraphVerifier> verifier;
void set_verifier(std::unique_ptr<GraphVerifier> verifier) {
this->verifier = std::move(verifier);
}
Expand Down
102 changes: 86 additions & 16 deletions include/graph_sketch_driver.h
Original file line number Diff line number Diff line change
@@ -1,61 +1,89 @@

#pragma once
#include <cache_guttering.h>
#include <gutter_tree.h>
#include <standalone_gutters.h>

#include "driver_configuration.h"
#include "graph_stream.h"
#include "worker_thread_group.h"
#ifdef VERIFY_SAMPLES_F
#include "graph_verifier.h"
#endif

class DriverException : public std::exception {
private:
std::string err_msg;
public:
DriverException(std::string msg) : err_msg(msg) {}
virtual const char* what() const throw() {
return err_msg.c_str();
}
};

/**
* GraphSketchDriver class:
* Driver for sketching algorithms on a single machine.
* Templatized by the "top level" sketching algorithm to manage.
*
* Algorithms need to implement the following functions to be managed by the driver
* Algorithms need to implement the following functions to be managed by the driver:
*
* 1) void allocate_worker_memory(size_t num_workers)
* For performance reasons it is often helpful for the algorithm to allocate some scratch
* space to be used by an individual worker threads. For example, in the connected
* components algorithm, we allocate a delta sketch for each worker.
* space to be used by individual worker threads. This scratch memory is managed by the
* algorithm. For example, in the connected components algorithm, we allocate a delta
* sketch for each worker.
*
* 2) size_t get_desired_updates_per_batch()
* Return the number of updates the algorithm would like us to batch. This serves as the
* maximum number of updates in a batch. We only provide smaller batches if force_flush'd
* maximum number of updates in a batch. We only provide smaller batches during
* prep_query()
*
* 3) node_id_t get_num_vertices()
* Returns the number of vertices in the Graph or an appropriate upper bound.
*
* 4) void pre_insert(GraphUpdate upd, node_id_t thr_id)
* Called before each update is added to the guttering system for the purpose of eager
* query heuristics. This function must be fast executing.
* query heuristics. This function must be thread-safe and fast executing. The algorithm
* may choose to make this function a no-op.
*
* 5) void apply_update_batch(size_t thr_id, node_id_t src_vertex, const std::vector<node_id_t>
* &dst_vertices)
* Called by worker threads to apply a batch of updates destined for a single vertex.
* Called by worker threads to apply a batch of updates destined for a single vertex. This
* function must be thread-safe.
*
* 6) bool has_cached_query()
* Check if the algorithm already has a cached answer for its query type. If so, the driver
* can skip flushing the updates and applying them in prep_query().
* 6) bool has_cached_query(int query_type)
* Check if the algorithm already has a cached answer for a given query type. If so, the
* driver can skip flushing the updates and applying them in prep_query(). The query_type
* should be defined by the algorithm as an enum (see cc_sketch_alg.h) but is typed in this
* code as an integer to ensure compatability across algorithms.
*
* 7) void print_configuration()
* Print the configuration of the algorithm. The algorithm may choose to print the
* configurations of subalgorithms as well.
*
* 8) void set_verifier(std::unique_ptr<GraphVerifier> verifier);
* If VERIFIER_SAMPLES_F is defined, then the driver provides the algorithm with a
* verifier. The verifier encodes the graph state at the time of a query losslessly
* and should be used by the algorithm to check its query answer. This is only used for
* correctness testing, not for production code.
*/
template <class Alg>
class GraphSketchDriver {
private:
GutteringSystem *gts;
Alg *sketching_alg;
GraphStream *stream;
#ifdef VERIFY_SAMPLES_F
GraphVerifier *verifier;
std::mutex verifier_mtx;
#endif

WorkerThreadGroup<Alg> *worker_threads;

size_t num_stream_threads;
static constexpr size_t update_array_size = 4000;

std::atomic<size_t> total_updates;
FRIEND_TEST(GraphTest, TestSupernodeRestoreAfterCCFailure);
public:
GraphSketchDriver(Alg *sketching_alg, GraphStream *stream, DriverConfiguration config,
size_t num_stream_threads = 1)
Expand Down Expand Up @@ -83,10 +111,14 @@ class GraphSketchDriver {
sketching_alg->print_configuration();

if (num_stream_threads > 1 && !stream->get_update_is_thread_safe()) {
std::cerr << "WARNING: stream get_update is not thread safe. Setting num inserters to 1"
<< std::endl;
std::cerr
<< "WARNING: stream get_update is not thread safe. Setting number of stream threads to 1"
<< std::endl;
num_stream_threads = 1;
}
#ifdef VERIFY_SAMPLES_F
verifier = new GraphVerifier(sketching_alg->get_num_vertices());
#endif

total_updates = 0;
std::cout << std::endl;
Expand All @@ -95,17 +127,29 @@ class GraphSketchDriver {
~GraphSketchDriver() {
delete worker_threads;
delete gts;
#ifdef VERIFY_SAMPLES_F
delete verifier;
#endif
}

/**
* Processes the stream until a given edge index, at which point the function returns
* @param break_edge_idx the breakpoint edge index. All updates up to but not including this
* index are processed by this call.
* @throws DriverException if we cannot set the requested breakpoint.
*/
void process_stream_until(edge_id_t break_edge_idx) {
if (!stream->set_break_point(break_edge_idx)) {
std::cerr << "ERROR: COULD NOT CORRECTLY SET BREAKPOINT!" << std::endl;
DriverException("Could not correctly set breakpoint: " + std::to_string(break_edge_idx));
exit(EXIT_FAILURE);
}
worker_threads->resume_workers();

auto task = [&](int thr_id) {
GraphStreamUpdate update_array[update_array_size];
#ifdef VERIFY_SAMPLES_F
GraphVerifier local_verifier(sketching_alg->get_num_vertices());
#endif

while (true) {
size_t updates = stream->get_update_buffer(update_array, update_array_size);
Expand All @@ -114,13 +158,21 @@ class GraphSketchDriver {
upd.edge = update_array[i].edge;
upd.type = static_cast<UpdateType>(update_array[i].type);
if (upd.type == BREAKPOINT) {
// reached the breakpoint. Update verifier if applicable and return
#ifdef VERIFY_SAMPLES_F
std::lock_guard<std::mutex> lk(verifier_mtx);
verifier->combine(local_verifier);
#endif
return;
}
else {
sketching_alg->pre_insert(upd, thr_id);
Edge edge = upd.edge;
gts->insert({edge.src, edge.dst}, thr_id);
gts->insert({edge.dst, edge.src}, thr_id);
#ifdef VERIFY_SAMPLES_F
local_verifier.edge_update(edge);
#endif
}
}
}
Expand All @@ -131,10 +183,15 @@ class GraphSketchDriver {

// wait for threads to finish
for (size_t i = 0; i < num_stream_threads; i++) threads[i].join();

// pass the verifier to the algorithm
#ifdef VERIFY_SAMPLES_F
sketching_alg->set_verifier(std::make_unique<GraphVerifier>(*verifier));
#endif
}

void prep_query() {
if (sketching_alg->has_cached_query()) {
void prep_query(int query_code) {
if (sketching_alg->has_cached_query(query_code)) {
flush_start = flush_end = std::chrono::steady_clock::now();
return;
}
Expand All @@ -151,6 +208,19 @@ class GraphSketchDriver {
sketching_alg->apply_update_batch(thr_id, src_vertex, dst_vertices);
}

#ifdef VERIFY_SAMPLES_F
/**
* checks that the verifier we constructed in process_stream_until matches another verifier
* @param expected the ground truth verifier
* @throws DriverException if the verifiers do not match
*/
void check_verifier(const GraphVerifier &expected) {
if (*verifier != expected) {
throw DriverException("Mismatch between driver verifier and expected verifier");
}
}
#endif

size_t get_total_updates() { return total_updates.load(); }

// time hooks for experiments
Expand Down
7 changes: 4 additions & 3 deletions include/return_types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// This file defines the query return types from the cc algorithm class
#pragma once
#include <cstddef>
#include <iterator>
#include <set>
Expand All @@ -20,8 +21,8 @@ class ConnectedComponents {
~ConnectedComponents();

std::vector<std::set<node_id_t>> get_component_sets();
bool is_connected(node_id_t a, node_id_t b) { return parent_arr[a] == parent_arr[b]; }
node_id_t size() { return num_cc; }
bool is_connected(node_id_t a, node_id_t b) const { return parent_arr[a] == parent_arr[b]; }
node_id_t size() const { return num_cc; }
};

// This class defines a spanning forest of a graph
Expand All @@ -32,5 +33,5 @@ class SpanningForest {
public:
SpanningForest(node_id_t num_vertices, const std::unordered_set<node_id_t> *spanning_forest);

const std::vector<Edge>& get_edges() { return edges; }
const std::vector<Edge>& get_edges() const { return edges; }
};
10 changes: 8 additions & 2 deletions include/sketch.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Sketch {
* @return The number of samples
*/
static size_t calc_cc_samples(node_id_t num_vertices, double f) {
return ceil(f * log2(num_vertices) / num_samples_div);
return std::max(size_t(18), (size_t) ceil(f * log2(num_vertices) / num_samples_div));
}

/**
Expand Down Expand Up @@ -191,8 +191,14 @@ class Sketch {
};

class OutOfSamplesException : public std::exception {
private:
std::string err_msg;
public:
OutOfSamplesException(size_t seed, size_t num_samples, size_t sample_idx)
: err_msg("This sketch (seed=" + std::to_string(seed) +
", max samples=" + std::to_string(num_samples) +
") cannot be sampled more times (cur idx=" + std::to_string(sample_idx) + ")!") {}
virtual const char* what() const throw() {
return "This sketch cannot be sampled more times!";
return err_msg.c_str();
}
};
Loading

0 comments on commit d0f3e5b

Please sign in to comment.