So, it's nice to be able to have runtime asserts, but we can't have them in a pure function.
On x86, when the processor encounters a floating-point error, it sets a sticky floating-point exception flag and carries on. You can check that flag at the end of the computation.
I prototyped a similar mechanism in my code. An error flag is raised during the computation. At the end of the computation, we check whether the error flag has been raised.
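A minimal sketch of the idea, assuming NumPy as a stand-in for `jax.numpy` so it runs without JAX. The pure step never asserts. It returns its result together with a boolean error-flag vector (one entry per check site), and only the host-side caller decides whether to raise. The names `pure_step` and `run_and_check` are hypothetical.

```python
import numpy as np


def pure_step(x):
    """Pure function: no asserts, it only records error flags."""
    # Check site 0: negative inputs (sqrt would be invalid).
    flag_negative = np.any(x < 0)
    # Keep computing anyway, the way x86 sets a flag and goes on.
    y = np.sqrt(np.abs(x))
    # Check site 1: non-finite intermediate values.
    flag_nonfinite = ~np.all(np.isfinite(y))
    errflag = np.array([flag_negative, flag_nonfinite], dtype=bool)
    return y.sum(), errflag


def run_and_check(x):
    """Host-side wrapper: inspect the flags once the computation is done."""
    total, errflag = pure_step(np.asarray(x, dtype=float))
    if errflag.any():
        raised = [k for k, flag in enumerate(errflag) if flag]
        raise ValueError(f"Exception flag raised at check sites {raised}")
    return total
```

For example, `run_and_check([1.0, 4.0, 9.0])` returns `6.0`, while `run_and_check([-1.0, 4.0])` raises with check site 0 reported.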
Unlike asserts, this method:
- Can be enabled in prod on the same graph. You can place the checks on a device of your choosing.
- Does not know which error came first. If there is a cascade of errors, you will not know which one started it. I don't think you can know unless you annotate the graph or provide a JAX op that timestamps execution.
- Does not interrupt errors that cause infinite loops, since the flag is only checked after the computation finishes.
The second point (not knowing which error came first) is the annoying one. My assumption is that errors should generally happen in the order in which they are constructed.
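Under that assumption, one heuristic is to report the lowest-index raised flag as the likely root cause and treat later flags as a possible cascade. A small sketch; `likely_first_error` is a hypothetical helper, and the result is only a guess because the flags carry no timing information.

```python
import numpy as np


def likely_first_error(errflag):
    """Return the index of the first raised flag, or None if none is raised.

    Assumes flags are indexed in construction order and that errors tend to
    occur in that order. The flags have no timestamps, so this is a heuristic.
    """
    raised = np.flatnonzero(np.asarray(errflag, dtype=bool))
    return int(raised[0]) if raised.size else None
```

For instance, `likely_first_error([False, True, True])` returns `1`.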
I've implemented it in my model.py but it should go in layers.
import logging

import jax
import numpy as np


def SafePmap(errsrc, step, **pmap_kwargs):
  # errsrc = jax_task.model, which contains errflag and errinfo.

  def StepWithErrflag(*args, **kwargs):
    # The step will construct errflag and errinfo.
    # This is pure and can be compiled.
    ret = list(step(*args, **kwargs))
    # Return the flags as an extra output so they leave the compiled step.
    ret += [errsrc.errflag]
    return ret

  compiled_step = jax.pmap(StepWithErrflag, **pmap_kwargs)

  def RunCheckCompileStep(*args, **kwargs):
    ret = compiled_step(*args, **kwargs)
    # This runs in Python land, so we can access Python objects.
    errinfo = errsrc.errinfo
    # pmap adds a leading device axis; OR the flags over it so an error
    # raised on any device is reported.
    errflag = np.any(np.asarray(ret[-1]), axis=0)
    if np.sum(errflag):
      logging.info('==== ERRORS found: %s', np.sum(errflag))
      for k, flag in enumerate(errflag):
        if flag:
          logging.info('== ERR[%d] at %s', k, '\n'.join(errinfo[k].format()))
      raise ValueError('Exception flag raised')
    # Strip the error flags before handing results back to the caller.
    return ret[:-1]

  return RunCheckCompileStep
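`SafePmap` relies on `errsrc.errflag` and `errsrc.errinfo`, which are not shown here. Below is one hypothetical way such an error source could be built, assuming each check site records its construction traceback (consistent with the `errinfo[k].format()` call) alongside a boolean flag. NumPy stands in for `jax.numpy` and a plain call stands in for `jax.pmap`, so there is no device axis; `ErrorSource`, `check`, `run_checked`, and `step` are illustrative names, not lingvo APIs.

```python
import logging
import traceback

import numpy as np


class ErrorSource:
    """Hypothetical holder for per-check error flags and their origins."""

    def __init__(self):
        self._flags = []
        self.errinfo = []  # errinfo[k] is the StackSummary of check site k

    def reset(self):
        self._flags.clear()
        self.errinfo.clear()

    def check(self, bad):
        """Record a flag instead of asserting; `bad` may be an array."""
        # Drop this frame so the summary ends at the caller of check().
        stack = traceback.extract_stack()[:-1]
        self.errinfo.append(traceback.StackSummary.from_list(stack))
        # Keep the flag as an array value; with jax.numpy, calling bool()
        # here would fail on a traced value.
        self._flags.append(np.any(bad))

    @property
    def errflag(self):
        return np.array(self._flags, dtype=bool)


def run_checked(errsrc, step, *args):
    """Run `step`, then check the flags on the host, like RunCheckCompileStep."""
    # Eager code rebuilds the flags on every call, so clear old ones first.
    errsrc.reset()
    ret = list(step(*args))
    errflag = errsrc.errflag
    if errflag.any():
        logging.info("==== ERRORS found: %s", int(errflag.sum()))
        for k, flag in enumerate(errflag):
            if flag:
                # StackSummary.format() lines already end with newlines.
                logging.info("== ERR[%d] at %s", k, "".join(errsrc.errinfo[k].format()))
        raise ValueError("Exception flag raised")
    return ret


errsrc = ErrorSource()


def step(x):
    errsrc.check(x < 0)             # check site 0
    y = np.sqrt(np.abs(x))
    errsrc.check(~np.isfinite(y))   # check site 1
    return (y.sum(),)
```

Calling `run_checked(errsrc, step, np.array([-1.0, 9.0]))` raises, and `errsrc.errinfo[0]` then points at the `errsrc.check(x < 0)` line inside `step`.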
Then, in lingvo/jax/train.py:
...
def train_and_evaluate_pmap(...):
  ...
  p_train_step = SafePmap(jax_task.model, train_step, donate_argnums=(0,), axis_name='batch')
There is also a pjit.pjit call in trainer_lib.py, but it's not on my code path.