-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #52 from N3PDF/cuda_example
Cuda example
- Loading branch information
Showing
11 changed files
with
251 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from vegasflow.configflow import DTYPE, DTYPEINT | ||
import time | ||
import numpy as np | ||
import tensorflow as tf | ||
from vegasflow.plain import plain_wrapper | ||
|
||
# Monte Carlo integration configuration
dim = 4                      # number of dimensions of the integrand
ncalls = np.int32(10 ** 6)   # events per iteration
n_iter = 5                   # number of integration iterations

# Custom TF operator compiled from integrand.cpp (+ integrand.cu.cpp on GPU)
integrand_module = tf.load_op_library('./integrand.so')
|
||
@tf.function
def wrapper_integrand(xarr, **kwargs):
    """Thin tf.function wrapper delegating the integrand to the custom op."""
    result = integrand_module.integrand_op(xarr)
    return result
|
||
@tf.function
def fully_python_integrand(xarr, **kwargs):
    """Pure-TensorFlow reference integrand: sum the coordinates of each event."""
    return tf.math.reduce_sum(xarr, axis=1)
|
||
if __name__ == "__main__":
    # Scale up the event count BEFORE printing it: the original printed the
    # pre-scaling value (1e6) and then ran 1e7 calls, so the header lied
    # about the actual workload.
    ncalls = 10 * ncalls
    print(f"VEGAS MC, ncalls={ncalls}:")
    start = time.time()
    r = plain_wrapper(wrapper_integrand, dim, n_iter, ncalls)
    end = time.time()
    print(f"Vegas took: time (s): {end-start}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
//#include "cuda_kernel.h" | ||
|
||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/shape_inference.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "integrand.h" | ||
|
||
/*
 * In this example we follow the TF guide for operation creation
 * https://www.tensorflow.org/guide/create_op
 * to create an integrand as a custom operator.
 *
 * To a first approximation, these operators are functions that take
 * a tensor and return a tensor.
 */
|
||
using namespace tensorflow; | ||
|
||
using GPUDevice = Eigen::GpuDevice; | ||
using CPUDevice = Eigen::ThreadPoolDevice; | ||
|
||
// CPU | ||
template <typename T> | ||
struct IntegrandOpFunctor<CPUDevice, T> { | ||
void operator()(const CPUDevice &d, const T *input, T *output, const int nevents, const int dims) { | ||
for (int i = 0; i < nevents; i++) { | ||
output[i] = 0.0; | ||
for(int j = 0; j < dims; j++) { | ||
output[i] += input[i,j]; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
|
||
/* The input and output type must be coherent with the types used in tensorflow | ||
* at this moment we are using float64 as default for vegasflow. | ||
* | ||
* The output shape is set to be (input_shape[0], ), i.e., number of events | ||
*/ | ||
/* Op interface registration: one float64 input `xarr` of shape
 * (nevents, ndim) and one float64 output `ret` of shape (nevents,).
 * These types must stay coherent with vegasflow's DTYPE (float64). */
REGISTER_OP("IntegrandOp")
    .Input("xarr: double")
    .Output("ret: double")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
      // Output shape is the leading (event) dimension of the input.
      auto nevents = c->Dim(c->input(0), 0);
      c->set_output(0, c->MakeShape({nevents}));
      return Status::OK();
    });
|
||
template<typename Device, typename T> | ||
class IntegrandOp: public OpKernel { | ||
public: | ||
explicit IntegrandOp(OpKernelConstruction* context): OpKernel(context) {} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
// Grab input tensor, which is expected to be of shape (nevents, ndim) | ||
const Tensor& input_tensor = context->input(0); | ||
auto input = input_tensor.tensor<T, 2>().data(); | ||
auto input_shape = input_tensor.shape(); | ||
|
||
// Create an output tensor of shape (nevents,) | ||
Tensor* output_tensor = nullptr; | ||
TensorShape output_shape; | ||
const int N = input_shape.dim_size(0); | ||
const int dims = input_shape.dim_size(1); | ||
output_shape.AddDim(N); | ||
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); | ||
|
||
auto output_flat = output_tensor->flat<T>().data(); | ||
|
||
// Perform the actual computation | ||
IntegrandOpFunctor<Device, T>()( | ||
context->eigen_device<Device>(), input, output_flat, N, dims | ||
); | ||
} | ||
}; | ||
|
||
// CPU kernel registration (the op is declared double-only, so no
// TypeConstraint attribute is involved).
#define REGISTER_CPU(T) \
  REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_CPU), IntegrandOp<CPUDevice, T>);
REGISTER_CPU(double);

// GPU kernel registration, compiled only together with the CUDA translation
// unit (the Makefile defines KERNEL_CUDA when nvcc is available).
#ifdef KERNEL_CUDA
#define REGISTER_GPU(T) \
  /* The functor is explicitly instantiated in integrand.cu.cpp. */ \
  extern template struct IntegrandOpFunctor<GPUDevice, T>; \
  REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_GPU), IntegrandOp<GPUDevice, T>);
REGISTER_GPU(double);
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#if KERNEL_CUDA | ||
#define EIGEN_USE_GPU | ||
|
||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "integrand.h" | ||
|
||
using namespace tensorflow; | ||
using GPUDevice = Eigen::GpuDevice; | ||
|
||
/* Device kernel: for each event i, sum its ndim coordinates.
 * `input` is a row-major (nevents, ndim) buffer; `output` has one slot per
 * event. Uses a grid-stride loop, so any launch configuration is correct
 * (this is a usage example, not a tuned kernel). */
template<typename T>
__global__ void IntegrandOpKernel(const T *input, T *output, const int nevents, const int ndim) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  for (int i = gid; i < nevents; i += stride) {
    // BUG FIX: the original wrote input[i,j]; the comma operator collapses
    // that to input[j], so every event summed the first row only.
    // Accumulate in a register and store once, instead of read-modify-write
    // on global memory in the inner loop.
    T sum = T(0);
    for (int j = 0; j < ndim; j++) {
      sum += input[i * ndim + j];
    }
    output[i] = sum;
  }
}
|
||
// Host-side launcher for the CUDA kernel; compare with the CPU functor at
// the top of integrand.cpp. Launches on the device stream owned by TF.
template <typename T>
void IntegrandOpFunctor<GPUDevice, T>::operator()(const GPUDevice &d, const T *input, T *output, const int nevents, const int dims) {
  // BUG FIX: the original launched 1024 blocks of 20 threads. 20 is not a
  // multiple of the warp size (32), so 12 of every 32 lanes were idle, and
  // the block/thread counts appear swapped. 256 threads/block is a safe
  // default; the grid-stride loop in the kernel handles any remainder.
  const int thread_per_block = 256;
  int block_count = (nevents + thread_per_block - 1) / thread_per_block;  // ceil-div
  if (block_count > 1024) block_count = 1024;  // cap; grid-stride covers the rest
  IntegrandOpKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(input, output, nevents, dims);
}

// Explicit instantiation of the double-precision functor used by the op.
template struct IntegrandOpFunctor<GPUDevice, double>;
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#ifndef KERNEL_INTEGRAND_ | ||
#define KERNEL_INTEGRAND_ | ||
|
||
namespace tensorflow { | ||
using Eigen::GpuDevice; | ||
|
||
template<typename Device, typename T> | ||
struct IntegrandOpFunctor { | ||
void operator()(const Device &d, const T *input, T *output, const int nevents, const int dims); | ||
}; | ||
|
||
#if KERNEL_CUDA | ||
template<typename T> | ||
struct IntegrandOpFunctor<Eigen::GpuDevice, T> { | ||
void operator()(const Eigen::GpuDevice &d, const T *input, T *output, const int nevents, const int dims); | ||
}; | ||
#endif | ||
|
||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
target_lib=integrand.so

# Compile/link flags exported by the installed tensorflow package
TF_CFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))' 2> /dev/null`
TF_LFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))' 2>/dev/null`

CXX=g++
CXFLAGS=-std=c++11 -shared -fPIC -O2
KERNEL_DEF=-D KERNEL_CUDA=1
NCCFLAGS=-std=c++11 $(KERNEL_DEF) -x cu -Xcompiler -fPIC --disable-warnings

# Enable the CUDA path only when nvcc is on PATH; otherwise build CPU-only.
ifeq (,$(shell which nvcc 2>/dev/null))
else
	NCC:=nvcc
	NCCLIB:=$(subst bin/nvcc,lib64, $(shell which nvcc))
	CXFLAGS+=$(KERNEL_DEF) -L$(NCCLIB) -lcudart
	kernel_comp=integrand.cu.o
endif

.PHONY: run clean

run: $(target_lib)
	@python cuda_example.py

%.cu.o: %.cu.cpp
	@echo "[$(NCC)] Integrating cuda kernel..."
	@$(NCC) $(NCCFLAGS) -c -o $@ $< $(TF_CFLAGS)

# BUG FIX: the original recipe expanded an undefined $(KERNEL) variable
# here; it evaluated to the empty string but was misleading, so it is removed.
%.so: %.cpp $(kernel_comp)
	@echo "[$(CXX)] Integrating operator..."
	@$(CXX) $(CXFLAGS) -o $@ $^ $(TF_CFLAGS) $(TF_LFLAGS)

clean:
	rm -f $(target_lib) $(kernel_comp)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
"""Monte Carlo integration with Tensorflow""" | ||
|
||
__version__ = '1.0.2' | ||
__version__ = "1.0.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters