Merge pull request #52 from N3PDF/cuda_example
Cuda example
scarlehoff authored Jul 28, 2020
2 parents b60f175 + 92ef4c0 commit 8e02416
Showing 11 changed files with 251 additions and 24 deletions.
28 changes: 28 additions & 0 deletions examples/cuda/cuda_example.py
@@ -0,0 +1,28 @@
from vegasflow.configflow import DTYPE, DTYPEINT
import time
import numpy as np
import tensorflow as tf
from vegasflow.plain import plain_wrapper

# MC integration setup
dim = 4
ncalls = np.int32(1e6)
n_iter = 5

integrand_module = tf.load_op_library('./integrand.so')

@tf.function
def wrapper_integrand(xarr, **kwargs):
    return integrand_module.integrand_op(xarr)

@tf.function
def fully_python_integrand(xarr, **kwargs):
    return tf.reduce_sum(xarr, axis=1)

if __name__ == "__main__":
    ncalls = 10 * ncalls
    print(f"Plain MC, ncalls={ncalls}:")
    start = time.time()
    r = plain_wrapper(wrapper_integrand, dim, n_iter, ncalls)
    end = time.time()
    print(f"Integration took: {end-start:.4f} s")
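(Not part of the diff.) As a quick sanity check after building integrand.so with the makefile below, the custom op can be compared against the pure-TF integrand, since both compute the same sum over dimensions; a minimal sketch:

# hypothetical check, not in the commit: the compiled op should agree
# with tf.reduce_sum on a small batch of random points
import numpy as np
import tensorflow as tf

integrand_module = tf.load_op_library("./integrand.so")
xarr = tf.random.uniform((100, 4), dtype=tf.float64)
np.testing.assert_allclose(
    integrand_module.integrand_op(xarr).numpy(),
    tf.reduce_sum(xarr, axis=1).numpy(),
)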
89 changes: 89 additions & 0 deletions examples/cuda/integrand.cpp
@@ -0,0 +1,89 @@
//#include "cuda_kernel.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "integrand.h"

/*
 * In this example we follow the TF guide for operation creation
 * https://www.tensorflow.org/guide/create_op
 * to create an integrand as a custom operator.
 *
 * To a first approximation, these operators are functions that take
 * a tensor and return a tensor.
 */

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;
using CPUDevice = Eigen::ThreadPoolDevice;

// CPU
template <typename T>
struct IntegrandOpFunctor<CPUDevice, T> {
    void operator()(const CPUDevice &d, const T *input, T *output, const int nevents, const int dims) {
        for (int i = 0; i < nevents; i++) {
            output[i] = 0.0;
            for (int j = 0; j < dims; j++) {
                // input is the flattened (nevents, dims) array, so index manually;
                // note that input[i, j] would invoke the comma operator, not 2D indexing
                output[i] += input[i * dims + j];
            }
        }
    }
};


/* The input and output types must be consistent with the types used in tensorflow;
 * at the moment we are using float64 as the default for vegasflow.
 *
 * The output shape is set to (input_shape[0], ), i.e., the number of events
 */
REGISTER_OP("IntegrandOp")
    .Input("xarr: double")
    .Output("ret: double")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
        c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0)}));
        return Status::OK();
    });

template <typename Device, typename T>
class IntegrandOp : public OpKernel {
 public:
    explicit IntegrandOp(OpKernelConstruction *context) : OpKernel(context) {}

    void Compute(OpKernelContext *context) override {
        // Grab the input tensor, which is expected to be of shape (nevents, ndim)
        const Tensor &input_tensor = context->input(0);
        auto input = input_tensor.tensor<T, 2>().data();
        auto input_shape = input_tensor.shape();

        // Create an output tensor of shape (nevents,)
        Tensor *output_tensor = nullptr;
        TensorShape output_shape;
        const int N = input_shape.dim_size(0);
        const int dims = input_shape.dim_size(1);
        output_shape.AddDim(N);
        OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

        auto output_flat = output_tensor->flat<T>().data();

        // Perform the actual computation
        IntegrandOpFunctor<Device, T>()(
            context->eigen_device<Device>(), input, output_flat, N, dims
        );
    }
};

// Register the CPU version of the kernel
#define REGISTER_CPU(T) \
    REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_CPU), IntegrandOp<CPUDevice, T>);
REGISTER_CPU(double);

// Register the GPU version
#ifdef KERNEL_CUDA
#define REGISTER_GPU(T) \
    /* The explicit instantiation lives in integrand.cu.cpp */ \
    extern template class IntegrandOpFunctor<GPUDevice, T>; \
    REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_GPU), IntegrandOp<GPUDevice, T>);
REGISTER_GPU(double);
#endif
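(Not part of the diff.) A usage sketch of what the dual registration buys, assuming a CUDA-enabled build of integrand.so: the same op name resolves to the CPU functor or the CUDA kernel depending on the device placement chosen from Python.

# hypothetical sketch: tf.device decides which registered kernel runs
import tensorflow as tf

integrand_module = tf.load_op_library("./integrand.so")
xarr = tf.random.uniform((1000, 4), dtype=tf.float64)

with tf.device("CPU:0"):
    res_cpu = integrand_module.integrand_op(xarr)

if tf.config.list_physical_devices("GPU"):
    with tf.device("GPU:0"):
        res_gpu = integrand_module.integrand_op(xarr)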
35 changes: 35 additions & 0 deletions examples/cuda/integrand.cu.cpp
@@ -0,0 +1,35 @@
#if KERNEL_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "integrand.h"

using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;

// This is the kernel that does the actual computation on device
template <typename T>
__global__ void IntegrandOpKernel(const T *input, T *output, const int nevents, const int ndim) {
    const auto gid = blockIdx.x * blockDim.x + threadIdx.x;
    // note: this is an example of usage, not an example of a very optimal anything...
    // grid-stride loop: each thread walks the events in steps of the total thread count
    for (int i = gid; i < nevents; i += blockDim.x * gridDim.x) {
        output[i] = 0.0;
        for (int j = 0; j < ndim; j++) {
            output[i] += input[i * ndim + j];
        }
    }
}

// But it still needs to be launched from within C++
// this bit is to be compared with the functor at the top of integrand.cpp
template <typename T>
void IntegrandOpFunctor<GPUDevice, T>::operator()(const GPUDevice &d, const T *input, T *output, const int nevents, const int dims) {
    const int block_count = 1024;
    const int thread_per_block = 20;
    IntegrandOpKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(input, output, nevents, dims);
}

template struct IntegrandOpFunctor<GPUDevice, double>;


#endif
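(Not part of the diff.) For readers new to grid-stride loops, a small pure-Python model of which events a given GPU thread processes with the launch configuration above:

# hypothetical model of the grid-stride loop above: with
# block_count * thread_per_block threads in total, thread `gid` handles
# events gid, gid + stride, gid + 2*stride, ...
def events_for_thread(gid, nevents, block_count=1024, thread_per_block=20):
    stride = block_count * thread_per_block
    return list(range(gid, nevents, stride))

print(events_for_thread(0, 10 ** 6)[:3])  # [0, 20480, 40960]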
21 changes: 21 additions & 0 deletions examples/cuda/integrand.h
@@ -0,0 +1,21 @@
#ifndef KERNEL_INTEGRAND_
#define KERNEL_INTEGRAND_

namespace tensorflow {
using Eigen::GpuDevice;

template <typename Device, typename T>
struct IntegrandOpFunctor {
    void operator()(const Device &d, const T *input, T *output, const int nevents, const int dims);
};

#if KERNEL_CUDA
template <typename T>
struct IntegrandOpFunctor<Eigen::GpuDevice, T> {
    void operator()(const Eigen::GpuDevice &d, const T *input, T *output, const int nevents, const int dims);
};
#endif

}

#endif
34 changes: 34 additions & 0 deletions examples/cuda/makefile
@@ -0,0 +1,34 @@
target_lib=integrand.so

TF_CFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))' 2> /dev/null`
TF_LFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))' 2>/dev/null`

CXX=g++
CXFLAGS=-std=c++11 -shared -fPIC -O2
KERNEL_DEF=-D KERNEL_CUDA=1
NCCFLAGS=-std=c++11 $(KERNEL_DEF) -x cu -Xcompiler -fPIC --disable-warnings

# Check whether nvcc is available
ifneq (,$(shell which nvcc 2>/dev/null))
NCC:=nvcc
NCCLIB:=$(subst bin/nvcc,lib64, $(shell which nvcc))
CXFLAGS+=$(KERNEL_DEF) -L$(NCCLIB) -lcudart
kernel_comp=integrand.cu.o
endif

.PHONY: run clean

run: $(target_lib)
	@python cuda_example.py

%.cu.o: %.cu.cpp
	@echo "[$(NCC)] Integrating cuda kernel..."
	@$(NCC) $(NCCFLAGS) -c -o $@ $< $(TF_CFLAGS)

%.so: %.cpp $(kernel_comp)
	@echo "[$(CXX)] Integrating operator..."
	@$(CXX) $(CXFLAGS) -o $@ $^ $(TF_CFLAGS) $(TF_LFLAGS)

clean:
	rm -f $(target_lib) $(kernel_comp)
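The TF_CFLAGS and TF_LFLAGS variables shell out to TensorFlow itself; what they expand to on a given machine can be inspected by hand with the same calls the makefile uses:

# the exact tf.sysconfig calls the makefile shells out to
import tensorflow as tf

print(" ".join(tf.sysconfig.get_compile_flags()))
print(" ".join(tf.sysconfig.get_link_flags()))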
2 changes: 1 addition & 1 deletion examples/simgauss_tf.py
@@ -1,8 +1,8 @@
# Place your function here
-from vegasflow.configflow import DTYPE, DTYPEINT
import time
import numpy as np
import tensorflow as tf
+from vegasflow.configflow import DTYPE, DTYPEINT
from vegasflow.vflow import vegas_wrapper
from vegasflow.plain import plain_wrapper

2 changes: 1 addition & 1 deletion src/vegasflow/__init__.py
@@ -1,3 +1,3 @@
"""Monte Carlo integration with Tensorflow"""

-__version__ = '1.0.2'
+__version__ = "1.0.2"
22 changes: 21 additions & 1 deletion src/vegasflow/configflow.py
@@ -1,9 +1,29 @@
"""
Define some constants, header style
"""
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
+# Most of this can be moved to a yaml file without loss of generality
import tensorflow as tf

+# uncomment this line for debugging to avoid compiling any tf.function
+# tf.config.experimental_run_functions_eagerly(True)
+
+# Configure logging
+import logging
+
+module_name = __name__.split(".")[0]
+logger = logging.getLogger(module_name)
+# Set level debug for development
+logger.setLevel(logging.DEBUG)
+# Create a handler and format it
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_format = logging.Formatter("[%(levelname)s] %(message)s")
+console_handler.setFormatter(console_format)
+logger.addHandler(console_handler)

# Define the tf numeric types
DTYPE = tf.float64
DTYPEINT = tf.int32
@@ -15,7 +35,7 @@
# Set up the logistics of the integration
# Events Limit limits how many events are done in one single run of the event_loop
# set it lower if hitting memory problems
-MAX_EVENTS_LIMIT = int(1e7)
+MAX_EVENTS_LIMIT = int(1e6)
# Select the list of devices to look for
DEFAULT_ACTIVE_DEVICES = ["GPU"] # , 'CPU']
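The handler is attached once to the top-level package logger; the per-module loggers added below in monte_carlo.py and vflow.py carry no handlers of their own and rely on propagation up the logging hierarchy. A minimal sketch of that mechanism (not part of the commit):

# hypothetical sketch of logger propagation: the child needs no handler,
# its records bubble up to the configured package logger
import logging

pkg = logging.getLogger("vegasflow")
pkg.setLevel(logging.DEBUG)
pkg.addHandler(logging.StreamHandler())

child = logging.getLogger("vegasflow.monte_carlo")
child.info("handled by the parent's handler via propagation")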

Expand Down
15 changes: 10 additions & 5 deletions src/vegasflow/monte_carlo.py
@@ -41,16 +41,21 @@
import tensorflow as tf
from vegasflow.configflow import MAX_EVENTS_LIMIT, DEFAULT_ACTIVE_DEVICES, DTYPE

+import logging
+
+logger = logging.getLogger(__name__)
+

def print_iteration(it, res, error, extra="", threshold=0.1):
    """ Checks the size of the result to select between
    scientific notation and floating point notation """
    # note: actually, the flag 'g' does this automatically
    # but I prefer to choose the precision myself...
    if res < threshold:
-        print(f"Result for iteration {it}: {res:.3e} +/- {error:.3e}" + extra)
+        logger.info(f"Result for iteration {it}: {res:.3e} +/- {error:.3e}" + extra)
    else:
-        print(f"Result for iteration {it}: {res:.4f} +/- {error:.4f}" + extra)
+        logger.info(f"Result for iteration {it}: {res:.4f} +/- {error:.4f}" + extra)


def _accumulate(accumulators):
""" Accumulate all the quantities in accumulators
@@ -135,8 +140,8 @@ def events_per_run(self, val):
        """ Set the number of events per single step """
        self._events_per_run = min(val, self.n_events)
        if self.n_events % self._events_per_run != 0:
-            print(
-                f"Warning, the number of events per run step {self._events_per_run} doesn't perfectly"
+            logger.warning(
+                f"The number of events per run step {self._events_per_run} doesn't perfectly"
                f" divide the number of events {self.n_events}, which can harm performance"
            )

@@ -379,7 +384,7 @@ def run_integration(self, n_iter, log_time=True, histograms=None):

        final_result = aux_res / weight_sum
        sigma = np.sqrt(1.0 / weight_sum)
-        print(f" > Final results: {final_result.numpy():g} +/- {sigma:g}")
+        logger.info(f" > Final results: {final_result.numpy():g} +/- {sigma:g}")
        return final_result, sigma


4 changes: 1 addition & 3 deletions src/vegasflow/plain.py
@@ -21,9 +21,7 @@ def _run_event(self, integrand, ncalls=None):
        # Jacobian
        xjac = 1.0 / self.n_events
        # Generate all the random numbers for this iteration
-        rnds = tf.random.uniform(
-            (n_events, self.n_dim), minval=0, maxval=1, dtype=DTYPE
-        )
+        rnds = tf.random.uniform((n_events, self.n_dim), minval=0, maxval=1, dtype=DTYPE)
        # Compute the integrand
        tmp = integrand(rnds, n_dim=self.n_dim, weight=xjac) * xjac
        tmp2 = tf.square(tmp)
23 changes: 10 additions & 13 deletions src/vegasflow/vflow.py
@@ -14,6 +14,10 @@
from vegasflow.monte_carlo import MonteCarloFlow, wrapper
from vegasflow.utils import consume_array_into_indices

+import logging
+
+logger = logging.getLogger(__name__)
+
FBINS = float_me(BINS_MAX)

# Auxiliary functions for Vegas
@@ -135,10 +139,7 @@ def while_body(bin_weight, n_bin, cur, prev):
    prev = fzero
    for _ in range(BINS_MAX - 1):
        bin_weight, n_bin, cur, prev = tf.while_loop(
-            while_check,
-            while_body,
-            (bin_weight, n_bin, cur, prev),
-            parallel_iterations=1,
+            while_check, while_body, (bin_weight, n_bin, cur, prev), parallel_iterations=1,
        )
        bin_weight -= ave_t
        delta = (cur - prev) * bin_weight / wei_t[n_bin]
@@ -232,8 +233,8 @@ def load_grid(self, file_name=None, numpy_grid=None):
        integrand_name = self.integrand.__name__
        integrand_grid = json_dict.get("integrand")
        if integrand_name != integrand_grid:
-            print(
-                f"WARNING: The grid was written for the integrand: {integrand_grid}"
+            logger.warning(
+                f"The grid was written for the integrand: {integrand_grid},"
                f" which is different from {integrand_name}"
            )
        # Now that everything is clear, let's load up the grid
@@ -255,7 +256,7 @@ def load_grid(self, file_name=None, numpy_grid=None):
                f"current settings is of {self.grid_bins} bins"
            )
        if file_name:
-            print(f" > SUCCESS: Loaded grid from {file_name}")
+            logger.info(f" > SUCCESS: Loaded grid from {file_name}")
        self.divisions.assign(numpy_grid)

    def refine_grid(self, arr_res2):
@@ -270,9 +271,7 @@ def refine_grid(self, arr_res2):
            Function not compiled
        """
        for j in range(self.n_dim):
-            new_divisions = refine_grid_per_dimension(
-                arr_res2[j, :], self.divisions[j, :]
-            )
+            new_divisions = refine_grid_per_dimension(arr_res2[j, :], self.divisions[j, :])
            self.divisions[j, :].assign(new_divisions)

    def _run_event(self, integrand, ncalls=None):
@@ -317,9 +316,7 @@ def _run_event(self, integrand, ncalls=None):
        # If the training is active, save the result of the integral sq
        for j in range(self.n_dim):
            arr_res2.append(
-                consume_array_into_indices(
-                    tmp2, ind[:, j : j + 1], self.grid_bins - 1
-                )
+                consume_array_into_indices(tmp2, ind[:, j : j + 1], self.grid_bins - 1)
            )
        arr_res2 = tf.reshape(arr_res2, (self.n_dim, -1))

