diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 04be0da..3f1d533 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -25,7 +25,7 @@ jobs: run: | pip install pylint # Error out only in actual errors - pylint src/*/*.py -E + pylint src/*/*.py -E -d E1123,E1120 pylint src/*/*.py --exit-zero - name: Test with pytest run: | diff --git a/.pylintrc b/.pylintrc index 75d5f50..a0febe5 100644 --- a/.pylintrc +++ b/.pylintrc @@ -2,8 +2,11 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-whitelist=numpy +# run arbitrary code. +extension-pkg-whitelist=numpy,tensorflow + +# Specify a score threshold to be exceeded before program exits with error. +fail-under=10 # Add files or directories to the blacklist. They should be base names, not # paths. @@ -17,18 +20,25 @@ ignore-patterns= # pygtk.require(). #init-hook= -# Use multiple processes to speed up Pylint. +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. jobs=2 -# List of plugins (as comma separated values of python modules names) to load, +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. load-plugins= # Pickle collected data for later comparisons. persistent=yes -# Specify a configuration file. -#rcfile= +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes # Allow loading of arbitrary C extensions. 
Extensions are imported into the # active Python interpreter and may run arbitrary code. @@ -38,124 +48,131 @@ unsafe-load-any-extension=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. confidence= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to +# file where it should appear only once). You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable= - print-statement, - parameter-unpacking, - unpacking-in-except, - old-raise-syntax, - backtick, - long-suffix, - old-ne-operator, - old-octal-literal, - import-star-module-level, - raw-checker-failed, - bad-inline-option, - locally-disabled, - locally-enabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - apply-builtin, - basestring-builtin, - buffer-builtin, - cmp-builtin, - coerce-builtin, - execfile-builtin, - file-builtin, - long-builtin, - raw_input-builtin, - reduce-builtin, - standarderror-builtin, - unicode-builtin, - xrange-builtin, - coerce-method, - delslice-method, - getslice-method, - setslice-method, - no-absolute-import, - old-division, - dict-iter-method, - dict-view-method, - next-method-called, - metaclass-assignment, - indexing-exception, - raising-string, - reload-builtin, - oct-method, - hex-method, - nonzero-method, - 
cmp-method, - input-builtin, - round-builtin, - intern-builtin, - unichr-builtin, - map-builtin-not-iterating, - zip-builtin-not-iterating, - range-builtin-not-iterating, - filter-builtin-not-iterating, - using-cmp-argument, - eq-without-hash, - div-method, - idiv-method, - rdiv-method, - exception-message-attribute, - invalid-str-codec, - sys-max-int, - bad-python3-import, - deprecated-string-function, - deprecated-str-translate-call, - invalid-name, - too-few-public-methods, - too-many-arguments, - bad-continuation, - redefined-outer-name, - missing-docstring, - bad-whitespace, - no-self-use, - no-else-return, - global-statement, - too-many-public-method, - too-many-ancestors +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + 
rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + invalid-name, + too-few-public-methods, + deprecated-itertools-function, + deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape, + E1123, # pylint is not able to deal with tensorflow + E1120, # same as above + C0330, # black indentation when breaking long lines is better + + # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. -enable= +enable=c-extension-no-member [REPORTS] -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details +# used to format the message information. 
See doc for all details. #msg-template= # Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio).You can also give a reporter class, eg +# and msvs (visual studio). You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. output-format=text -# Tells whether to display a full report or only the messages +# Tells whether to display a full report or only the messages. reports=no # Activate the evaluation score. @@ -167,26 +184,77 @@ score=yes # Maximum number of nested blocks for function / method body max-nested-blocks=5 +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit -[MISCELLANEOUS] -# List of note tags to take in consideration, separated by a comma. -notes=FIXME,XXX,TODO +[STRING] +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no -[SIMILARITIES] +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no -# Ignore comments when computing similarities. -ignore-comments=yes -# Ignore docstrings when computing similarities. -ignore-docstrings=yes +[TYPECHECK] -# Ignore imports when computing similarities. -ignore-imports=no +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager -# Minimum lines number of a similarity. 
-min-similarity-lines=4 +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=tensorflow + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. 
+missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= [BASIC] @@ -237,7 +305,11 @@ function-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ # Good variable names which should always be accepted, separated by a comma good-names=i,j,k,ex,Run,_ -# Include a hint for the correct naming format with invalid-name +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming hint for inline iteration names @@ -268,6 +340,7 @@ no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Naming hint for variable names @@ -280,7 +353,7 @@ variable-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. +# you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. @@ -290,12 +363,12 @@ allow-global-unused-variables=yes # name must start or end with one of those strings. callbacks=cb_,_cb -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. Default to name -# with leading underscore +# with leading underscore. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. 
@@ -314,7 +387,7 @@ expected-line-ending-format= # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ -# Number of spaces of indent required inside a hanging or continued line. +# Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 @@ -324,7 +397,7 @@ indent-string=' ' # Maximum number of characters on a single line. max-line-length=100 -# Maximum number of lines in a module +# Maximum number of lines in a module. max-module-lines=1000 # List of optional constructs for which whitespace checking is disabled. `dict- @@ -342,90 +415,104 @@ single-line-class-stmt=no single-line-if-stmt=no +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +#notes-rgx= + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + [SPELLING] -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= -# A path to a file that contains private dictionary; one word per line. +# A path to a file that contains the private dictionary; one word per line. spelling-private-dict-file= -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. 
+# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. spelling-store-unknown-words=no -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager +[IMPORTS] -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. 
+analyse-fallback-blocks=no -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules=matplotlib.cm +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= -[LOGGING] +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= [DESIGN] -# Maximum number of arguments for function / method +# Maximum number of arguments for function / method. max-args=5 # Maximum number of attributes for a class (see R0902). 
max-attributes=7 -# Maximum number of boolean expressions in a if statement +# Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 -# Maximum number of branch for function / method body +# Maximum number of branch for function / method body. max-branches=12 # Maximum number of locals for function / method body @@ -437,10 +524,10 @@ max-parents=7 # Maximum number of public methods for a class (see R0904). max-public-methods=20 -# Maximum number of return / yield for function / method body +# Maximum number of return / yield for function / method body. max-returns=6 -# Maximum number of statements in function / method body +# Maximum number of statements in function / method body. max-statements=50 # Minimum number of public methods for a class (see R0903). @@ -450,54 +537,29 @@ min-public-methods=2 [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__,__new__,setUp +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ # List of member names, which should be excluded from the protected access # warning. -exclude-protected=_asdict,_fields,_replace,_source,_make +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[IMPORTS] - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. 
-analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant +valid-metaclass-classmethod-first-arg=cls [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception diff --git a/PKGBUILD b/PKGBUILD new file mode 100644 index 0000000..99a2819 --- /dev/null +++ b/PKGBUILD @@ -0,0 +1,40 @@ +# Maintainer: Juacrumar + +pkgname=python-vegasflow +_name=vegasflow +pkgver=1.0.2 +pkgrel=1 +pkgdesc='Monte Carlo integration library written in Python and based on the TensorFlow framework' +arch=('any') +url="https://vegasflow.readthedocs.io/" +license=('GPL3') +depends=("python>=3.6" + "python-tensorflow" + "python-joblib" + "python-numpy") +optdepends=("python-cffi: interfacing vegasflow with C code" + "python-tensorflow-cuda: GPU support") +# checkdepends=("python-pytest") +provides=("vegasflow") +changelog= +source=("https://github.com/N3PDF/vegasflow/archive/v.${pkgver}.tar.gz") +md5sums=("118fa9906f588ab7ecd320728c478ade") + +prepare() { + cd "$_name-v.$pkgver" +} + +# check() { +# cd "$_name-v.$pkgver" +# pytest +# } + +build() { + cd "$_name-v.$pkgver" + python setup.py build +} + +package() { + cd "$_name-v.$pkgver" + python setup.py 
install --root="$pkgdir" --optimize=2 --skip-build +} diff --git a/README.md b/README.md index 6f4beb1..b57b85f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ [![Tests](https://github.com/N3PDF/vegasflow/workflows/pytest/badge.svg)](https://github.com/N3PDF/vegasflow/actions?query=workflow%3A%22pytest%22) [![Documentation Status](https://readthedocs.org/projects/vegasflow/badge/?version=latest)](https://vegasflow.readthedocs.io/en/latest/?badge=latest) [![DOI](https://zenodo.org/badge/226363558.svg)](https://zenodo.org/badge/latestdoi/226363558) -[![Anaconda-Server Badge](https://anaconda.org/conda-forge/vegasflow/badges/installer/conda.svg)](https://conda.anaconda.org/conda-forge) +[![Anaconda-Server Badge](https://anaconda.org/conda-forge/vegasflow/badges/installer/conda.svg)](https://anaconda.org/conda-forge/vegasflow) +[![AUR](https://img.shields.io/aur/version/python-vegasflow)](https://aur.archlinux.org/packages/python-vegasflow/) # VegasFlow diff --git a/src/vegasflow/configflow.py b/src/vegasflow/configflow.py index 248472e..a390e7c 100644 --- a/src/vegasflow/configflow.py +++ b/src/vegasflow/configflow.py @@ -4,7 +4,7 @@ # Most of this can be moved to a yaml file without loss of generality import tensorflow as tf -# Define the tensorflow numberic types +# Define the tf numeric types DTYPE = tf.float64 DTYPEINT = tf.int32 @@ -17,15 +17,17 @@ # set it lower if hitting memory problems MAX_EVENTS_LIMIT = int(1e7) # Select the list of devices to look for -DEFAULT_ACTIVE_DEVICES = ['GPU']#, 'CPU'] +DEFAULT_ACTIVE_DEVICES = ["GPU"] # , 'CPU'] # Create wrappers in order to have numbers of the correct type def int_me(i): - return tf.constant(i, dtype=DTYPEINT) + """ Casts any integer to DTYPEINT """ + return tf.cast(i, dtype=DTYPEINT) def float_me(i): - return tf.constant(i, dtype=DTYPE) + """ Cast any float to DTYPE """ + return tf.cast(i, dtype=DTYPE) ione = int_me(1) diff --git a/src/vegasflow/monte_carlo.py b/src/vegasflow/monte_carlo.py index 
de1f471..7f3968a 100644 --- a/src/vegasflow/monte_carlo.py +++ b/src/vegasflow/monte_carlo.py @@ -34,7 +34,6 @@ import time import copy -import inspect import threading from abc import abstractmethod, ABC import joblib @@ -53,6 +52,28 @@ def print_iteration(it, res, error, extra="", threshold=0.1): else: print(f"Result for iteration {it}: {res:.4f} +/- {error:.4f}" + extra) +def _accumulate(accumulators): + """ Accumulate all the quantities in accumulators + The default accumulation is implemented for tensorflow tensors + as a sum of all partial results. + + Parameters + ---------- + `accumulators`: list of tensorflow tensors + + Returns + ------- + `results`: `sum` for each element of the accumulators + + Function not compiled + """ + results = [] + len_acc = len(accumulators[0]) + for i in range(len_acc): + total = tf.reduce_sum([acc[i] for acc in accumulators], axis=0) + results.append(total) + return results + + class MonteCarloFlow(ABC): """ @@ -115,7 +136,8 @@ def events_per_run(self, val): self._events_per_run = min(val, self.n_events) if self.n_events % self._events_per_run != 0: print( - f"Warning, the number of events per run step {self._events_per_run} doesn't perfectly divide the number of events {self.n_events}, which can harm performance" + f"Warning, the number of events per run step {self._events_per_run} doesn't perfectly " + f"divide the number of events {self.n_events}, which can harm performance" ) @property @@ -169,29 +191,7 @@ def release_device(self, device): finally: self.lock.release() - def accumulate(self, accumulators): - """ Accumulate all the quantities in accumulators - The default accumulation is implemented for tensorflow tensors - as a sum of all partial results. 
- - Parameters - ---------- - `accumulators`: list of tensorflow tensors - - Returns - ------- - `results`: `sum` for each element of the accumulators - - Function not compiled - """ - results = [] - len_acc = len(accumulators[0]) - for i in range(len_acc): - total = tf.reduce_sum([acc[i] for acc in accumulators], axis=0) - results.append(total) - return results - - def device_run(self, ncalls, sent_pc = 100.0, **kwargs): + def device_run(self, ncalls, sent_pc=100.0, **kwargs): """ Wrapper function to select a specific device when running the event If the devices were not set, tensorflow default will be used @@ -204,7 +204,7 @@ def device_run(self, ncalls, sent_pc = 100.0, **kwargs): `result`: raw result from the integrator """ if self._verbose: - print(f"Events sent to the computing device: {sent_pc:.1f} %", end='\r') + print(f"Events sent to the computing device: {sent_pc:.1f} %", end="\r") if not self.event: raise RuntimeError("Compile must be ran before running any iterations") if self.devices: @@ -242,7 +242,7 @@ def run_event(self, **kwargs): pc = 0.0 while events_left > 0: ncalls = min(events_left, self.events_per_run) - pc += ncalls/self.n_events*100 + pc += ncalls / self.n_events * 100 percentages.append(pc) events_to_do.append(ncalls) events_left -= self.events_per_run @@ -250,7 +250,7 @@ def run_event(self, **kwargs): if self.devices: running_pool = [] for ncalls, pc in zip(events_to_do, percentages): - delay_job = joblib.delayed(self.device_run)(ncalls, sent_pc = pc,**kwargs) + delay_job = joblib.delayed(self.device_run)(ncalls, sent_pc=pc, **kwargs) running_pool.append(delay_job) accumulators = self.pool(running_pool) else: @@ -258,7 +258,7 @@ def run_event(self, **kwargs): for ncalls, pc in zip(events_to_do, percentages): res = self.device_run(ncalls, sent_pc=pc, **kwargs) accumulators.append(res) - return self.accumulate(accumulators) + return _accumulate(accumulators) def compile(self, integrand, compilable=True): """ Receives an integrand, prepares 
it for integration @@ -383,9 +383,7 @@ def run_integration(self, n_iter, log_time=True, histograms=None): return final_result, sigma -def wrapper( - integrator_class, integrand, n_dim, n_iter, total_n_events, compilable=True -): +def wrapper(integrator_class, integrand, n_dim, n_iter, total_n_events, compilable=True): """ Convenience wrapper Parameters diff --git a/src/vegasflow/plain.py b/src/vegasflow/plain.py index 9038a37..4edb249 100644 --- a/src/vegasflow/plain.py +++ b/src/vegasflow/plain.py @@ -2,7 +2,7 @@ Plain implementation of the plainest possible MonteCarlo """ -from vegasflow.configflow import DTYPE, DTYPEINT, fone, fzero, float_me +from vegasflow.configflow import DTYPE, fone, fzero from vegasflow.monte_carlo import MonteCarloFlow, wrapper import tensorflow as tf @@ -42,4 +42,5 @@ def _run_iteration(self): def plain_wrapper(*args): + """ Wrapper around PlainFlow """ return wrapper(PlainFlow, *args) diff --git a/src/vegasflow/utils.py b/src/vegasflow/utils.py index 508a1e7..39e6c54 100644 --- a/src/vegasflow/utils.py +++ b/src/vegasflow/utils.py @@ -5,6 +5,7 @@ import tensorflow as tf from vegasflow.configflow import DTYPEINT, fzero + @tf.function def consume_array_into_indices(input_arr, indices, result_size): """ @@ -30,8 +31,8 @@ def consume_array_into_indices(input_arr, indices, result_size): `final_result` Array of size `result_size` """ - all_bins = tf.range(result_size, dtype = DTYPEINT) + all_bins = tf.range(result_size, dtype=DTYPEINT) eq = tf.transpose(tf.equal(indices, all_bins)) res_tmp = tf.where(eq, input_arr, fzero) - final_result = tf.reduce_sum(res_tmp, axis = 1) + final_result = tf.reduce_sum(res_tmp, axis=1) return final_result diff --git a/src/vegasflow/vflow.py b/src/vegasflow/vflow.py index 7cf0cc3..1b14572 100644 --- a/src/vegasflow/vflow.py +++ b/src/vegasflow/vflow.py @@ -9,7 +9,7 @@ import numpy as np import tensorflow as tf -from vegasflow.configflow import DTYPE, DTYPEINT, fone, fzero, float_me, ione, izero +from 
vegasflow.configflow import DTYPE, DTYPEINT, fone, fzero, float_me, ione from vegasflow.configflow import BINS_MAX, ALPHA from vegasflow.monte_carlo import MonteCarloFlow, wrapper from vegasflow.utils import consume_array_into_indices @@ -162,6 +162,7 @@ def __init__(self, n_dim, n_events, train=True, **kwargs): # otherwise it will be frozen self.train = train self.iteration_content = None + self.compile_args = None # Initialize grid self.grid_bins = BINS_MAX + 1 @@ -215,10 +216,12 @@ def load_grid(self, file_name=None, numpy_grid=None): """ if file_name is not None and numpy_grid is not None: raise ValueError( - "Received both a numpy grid and a file_name to load the grid from. Ambiguous call to `load_grid`" + "Received both a numpy grid and a file_name to load the grid from. " + "Ambiguous call to `load_grid`" ) + # If it received a file, loads up the grid - elif file_name: + if file_name: with open(file_name, "r") as f: json_dict = json.load(f) # First check the parameters of the grid are unchanged @@ -230,7 +233,8 @@ def load_grid(self, file_name=None, numpy_grid=None): integrand_grid = json_dict.get("integrand") if integrand_name != integrand_grid: print( - f"WARNING: The grid was written for the integrand: {integrand_grid} which is different from {integrand_name}" + f"WARNING: The grid was written for the integrand: {integrand_grid} " + f"which is different from {integrand_name}" ) # Now that everything is clear, let's load up the grid numpy_grid = np.array(json_dict["grid"]) @@ -242,11 +246,13 @@ def load_grid(self, file_name=None, numpy_grid=None): # Check that the grid has the right dimensions if grid_dim is not None and self.n_dim != grid_dim: raise ValueError( - f"Received a {grid_dim}-dimensional grid while VegasFlow was instantiated with {self.n_dim} dimensions" + f"Received a {grid_dim}-dimensional grid while VegasFlow " + f"was instantiated with {self.n_dim} dimensions" ) if grid_bins is not None and self.grid_bins != grid_bins: raise ValueError( - 
f"The received grid contains {grid_bins} bins while the current settings is of {self.grid_bins} bins" + f"The received grid contains {grid_bins} bins while the " + f"current settings is of {self.grid_bins} bins" ) if file_name: print(f" > SUCCESS: Loaded grid from {file_name}") @@ -311,7 +317,9 @@ def _run_event(self, integrand, ncalls=None): # If the training is active, save the result of the integral sq for j in range(self.n_dim): arr_res2.append( - consume_array_into_indices(tmp2, ind[:, j : j + 1], self.grid_bins-1) + consume_array_into_indices( + tmp2, ind[:, j : j + 1], self.grid_bins - 1 + ) ) arr_res2 = tf.reshape(arr_res2, (self.n_dim, -1)) @@ -334,6 +342,10 @@ def compile(self, integrand, compilable=True, **kwargs): self.iteration_content = self._iteration_content def recompile(self): + """ Forces recompilation with the same arguments that have + previously been used for compilation""" + if self.compile_args is None: + raise RuntimeError("recompile was called without ever having called compile") a = self.compile_args self.compile(a[0], a[1], **a[2])