Update to latest codebase of jax_verify
bunelr committed Aug 17, 2023
1 parent 391dca1 commit fe8ea3c
Showing 101 changed files with 4,366 additions and 1,820 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion examples/run_boundprop.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion examples/run_examples.sh
@@ -1,4 +1,4 @@
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 changes: 4 additions & 3 deletions examples/run_lp_solver.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@
import jax_verify
from jax_verify.extensions.sdp_verify import utils
from jax_verify.src.linear import forward_linear_bounds
from jax_verify.src.mip_solver.solve_relaxation import solve_planet_relaxation
import numpy as np

MLP_PATH = 'models/raghunathan18_pgdnn.pkl'
@@ -103,11 +104,11 @@ def main(unused_args):
jnp.ones_like(dummy_output[0, ...]),
jnp.zeros_like(dummy_output[0, ...]))
objective_bias = 0.
value, _, status = jax_verify.solve_planet_relaxation(
value, _, status = solve_planet_relaxation(
logits_fn, init_bound, boundprop_transform, objective,
objective_bias, index=0)
logging.info('Relaxation LB is : %f, Status is %s', value, status)
value, _, status = jax_verify.solve_planet_relaxation(
value, _, status = solve_planet_relaxation(
logits_fn, init_bound, boundprop_transform, -objective,
objective_bias, index=0)
logging.info('Relaxation UB is : %f, Status is %s', -value, status)
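With the root-level alias removed, the script now imports the solver from its defining module. A minimal before/after sketch of the call-site migration, reusing the script's own variables (`logits_fn`, `init_bound`, `boundprop_transform`, `objective`, `objective_bias`) and assuming the function signature itself is unchanged:

```python
# Before this commit, the solver was re-exported at the package root:
#   value, _, status = jax_verify.solve_planet_relaxation(...)
# After it, import the function from its defining module:
from jax_verify.src.mip_solver.solve_relaxation import solve_planet_relaxation

# Lower bound on the objective over the PLANET relaxation; negating the
# objective (as the script does above) yields the upper bound instead.
value, _, status = solve_planet_relaxation(
    logits_fn, init_bound, boundprop_transform, objective,
    objective_bias, index=0)
```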
5 changes: 2 additions & 3 deletions examples/run_sdp_verify.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
r"""Run SDP verification for adversarial robustness specification.
Example launch commands which achieve good results:
@@ -232,7 +231,7 @@ def run_verification(writer):
'ibp_bound': ibp_bound,
}
output_dict.update(info)
jax_to_np = lambda x: np.array(x) if isinstance(x, jnp.DeviceArray) else x
jax_to_np = lambda x: np.array(x) if isinstance(x, jax.Array) else x
output_dict = jax.tree_map(jax_to_np, output_dict)
writer.write(output_dict)
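`jnp.DeviceArray` has been removed from recent JAX releases, so the `isinstance` check now targets `jax.Array`, the unified array type. A self-contained sketch of the conversion idiom used on this line (the dictionary contents are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

output_dict = {'sdp_bound': jnp.asarray(0.25), 'verified': True}

# Convert JAX arrays to NumPy for serialization; leave every other
# leaf (bools, strings, plain floats, ...) untouched.
jax_to_np = lambda x: np.array(x) if isinstance(x, jax.Array) else x
output_dict = jax.tree_map(jax_to_np, output_dict)
```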

9 changes: 3 additions & 6 deletions jax_verify/__init__.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to perform verificaton on Neural Networks.
"""
"""Library to perform verification on Neural Networks."""

from jax_verify.src.bound_propagation import IntervalBound
from jax_verify.src.ibp import bound_transform as ibp_transform
from jax_verify.src.ibp import interval_bound_propagation
from jax_verify.src.ibp import IntervalBound
from jax_verify.src.intersection import IntersectionBoundTransform
from jax_verify.src.linear.backward_crown import backward_crown_bound_propagation
from jax_verify.src.linear.backward_crown import backward_fastlin_bound_propagation
from jax_verify.src.linear.backward_crown import crownibp_bound_propagation
from jax_verify.src.linear.forward_linear_bounds import forward_crown_bound_propagation
from jax_verify.src.linear.forward_linear_bounds import forward_fastlin_bound_propagation
from jax_verify.src.linear.forward_linear_bounds import ibpforwardfastlin_bound_propagation
from jax_verify.src.mip_solver.cvxpy_relaxation_solver import CvxpySolver
from jax_verify.src.mip_solver.solve_relaxation import solve_planet_relaxation
from jax_verify.src.nonconvex.methods import nonconvex_constopt_bound_propagation
from jax_verify.src.nonconvex.methods import nonconvex_ibp_bound_propagation
from jax_verify.src.utils import open_file
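Two of these export changes matter to downstream code: `IntervalBound` is now backed by `jax_verify.src.bound_propagation` rather than `jax_verify.src.ibp`, and `solve_planet_relaxation` is no longer re-exported at the root (see the `run_lp_solver.py` hunk above). A hedged sketch of the root-level API that remains, with a toy stand-in network:

```python
import jax.numpy as jnp
import jax_verify

def model_fn(x):
  # Toy stand-in network for illustration.
  return jnp.sum(jnp.maximum(x, 0.), axis=-1)

# IntervalBound is still constructed from elementwise lower/upper bounds.
input_bound = jax_verify.IntervalBound(jnp.zeros((1, 4)), jnp.ones((1, 4)))

# interval_bound_propagation remains the root-level IBP entry point.
output_bound = jax_verify.interval_bound_propagation(model_fn, input_bound)
print(output_bound.lower, output_bound.upper)
```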
8 changes: 4 additions & 4 deletions jax_verify/extensions/functional_lagrangian/attacks.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,8 +27,8 @@


IntervalBound = jax_verify.IntervalBound
Tensor = jnp.DeviceArray
PRNGKey = jnp.DeviceArray
Tensor = jax.Array
PRNGKey = jax.Array
DataSpec = verify_utils.DataSpec
LayerParams = verify_utils.LayerParams
ModelParams = verify_utils.ModelParams
@@ -225,7 +225,7 @@ def max_objective_fn_adversarial_softmax(x, prng_key):
else:
raise ValueError('Unsupported spec.')

return _run_attack(
return _run_attack( # pytype: disable=bad-return-type # jax-devicearray
max_objective_fn=max_objective_fn,
projection_fn=projection_fn,
x_init=data_spec.input,
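The `Tensor` and `PRNGKey` aliases previously pointed at the removed `jnp.DeviceArray`; after JAX's array unification, `jax.Array` is the single runtime type covering both ordinary tensors and the keys produced by `jax.random.PRNGKey`. An illustrative sketch (the `random_perturbation` helper is hypothetical, not part of this module):

```python
import jax
import jax.numpy as jnp

# One runtime type now covers activations and PRNG keys alike.
Tensor = jax.Array
PRNGKey = jax.Array

def random_perturbation(x: Tensor, key: PRNGKey, epsilon: float) -> Tensor:
  """Hypothetical helper: uniform noise within an epsilon-ball around x."""
  noise = jax.random.uniform(key, x.shape, minval=-epsilon, maxval=epsilon)
  return x + noise

print(random_perturbation(jnp.zeros((3,)), jax.random.PRNGKey(0), 0.1))
```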
9 changes: 5 additions & 4 deletions jax_verify/extensions/functional_lagrangian/bounding.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -330,8 +330,9 @@ def _bounds_from_cnn_layer(self, index):
self._cnn_bounds[layer_index].ub)


def _get_reciprocal_bound(l: jnp.array, u: jnp.array,
logits_params: LayerParams, label: int) -> jnp.array:
def _get_reciprocal_bound(
l: jnp.ndarray, u: jnp.ndarray, logits_params: LayerParams, label: int
) -> jnp.ndarray:
"""Helped for computing bound on label softmax given interval bounds on pre logits."""

def fwd(x, w, b):
@@ -378,4 +379,4 @@ def upper_bound_log_softmax(
Upper bound on log softmax of target label.
"""
fwd_bound = _get_reciprocal_bound(l, u, logits_params, target_label)
return -jax.nn.logsumexp(-fwd_bound.upper)
return -jax.nn.logsumexp(-fwd_bound.upper) # pytype: disable=attribute-error # jnp-array
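The annotation change in this file is more than cosmetic: `jnp.array` is the array-constructing function, so annotating with it names a callable, while `jnp.ndarray` is the actual array type that pytype can check. A minimal sketch of the corrected style, using a hypothetical `midpoint` helper:

```python
import jax.numpy as jnp

# Pre-commit style (misleading): jnp.array is a function, not a type.
#   def midpoint(l: jnp.array, u: jnp.array) -> jnp.array: ...

# Post-commit style (checkable): jnp.ndarray is the array type alias.
def midpoint(l: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray:
  """Elementwise midpoint of the interval [l, u]."""
  return 0.5 * (l + u)

print(midpoint(jnp.zeros(2), jnp.ones(2)))  # [0.5 0.5]
```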
2 changes: 1 addition & 1 deletion jax_verify/extensions/functional_lagrangian/data.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
48 changes: 25 additions & 23 deletions jax_verify/extensions/functional_lagrangian/dual_build.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,14 +29,14 @@
from jax_verify.src import bound_propagation
from jax_verify.src import graph_traversal
from jax_verify.src import synthetic_primitives
from jax_verify.src.types import Nest
import numpy as np
import optax

Params = verify_utils.Params
ParamsTypes = verify_utils.ParamsTypes
InnerVerifInstance = verify_utils.InnerVerifInstance
LagrangianForm = lag_form.LagrangianForm
Nest = bound_propagation.Nest


class DualOp(bound_propagation.Bound):
@@ -46,9 +46,10 @@ def __init__(
self,
name,
base_bound: bound_propagation.Bound,
affine_fn: Callable[[jnp.array], jnp.array],
inputs: Optional[Sequence[Union['DualOp', jnp.array]]] = None,
relu_preact_name: Optional[int] = None):
affine_fn: Callable[[jnp.ndarray], jnp.ndarray],
inputs: Optional[Sequence[Union['DualOp', jnp.ndarray]]] = None,
relu_preact_name: Optional[int] = None,
):
self.name = name
self._base_bound = base_bound
self._affine_fn = affine_fn
@@ -60,11 +61,11 @@ def base_bound(self) -> bound_propagation.Bound:
return self._base_bound

@property
def lower(self) -> jnp.array:
def lower(self) -> jnp.ndarray:
return self._base_bound.lower

@property
def upper(self) -> jnp.array:
def upper(self) -> jnp.ndarray:
return self._base_bound.upper

@property
@@ -89,20 +90,20 @@ def relu_preact_name(self) -> int:
return self._relu_preact_name

@property
def inputs(self) -> Sequence[Union['DualOp', jnp.array]]:
def inputs(self) -> Sequence[Union['DualOp', jnp.ndarray]]:
if self._inputs is None:
raise ValueError('Input node does not have inputs')
return self._inputs


_affine_primitives_list = (
bound_propagation.AFFINE_PRIMITIVES +
bound_propagation.RESHAPE_PRIMITIVES +
[lax.div_p]
)
_affine_primitives_list = [
*bound_propagation.AFFINE_PRIMITIVES,
*bound_propagation.RESHAPE_PRIMITIVES,
lax.div_p,
]


class _LagrangianTransform(bound_propagation.GraphTransform[DualOp]):
class _LagrangianTransform(graph_traversal.GraphTransform[DualOp]):
"""Identifies graph nodes having Lagrangian dual contributions."""

def __init__(self, boundprop_transform: bound_propagation.BoundTransform):
@@ -121,7 +122,7 @@ def input_transform(self, context, input_bound):
def primitive_transform(self, context, primitive, *args, **params):
interval_args = [arg.base_bound if isinstance(arg, DualOp) else arg
for arg in args]
out_bounds = self._boundprop_transform.equation_transform(
out_bounds, = self._boundprop_transform.equation_transform(
context, primitive, *interval_args, **params)

if primitive in _affine_primitives_list:
@@ -159,9 +160,9 @@ def solve_max(
self,
inner_dual_vars: Any,
opt_instance: InnerVerifInstance,
key: jnp.array,
key: jnp.ndarray,
step: int,
) -> jnp.array:
) -> jnp.ndarray:
"""Solve maximization problem of opt_instance.
Args:
@@ -247,8 +248,8 @@ def init_duals(
boundprop_transform: bound_propagation.BoundTransform,
spec_type: verify_utils.SpecType,
affine_before_relu: bool,
spec_fn: Callable[..., jnp.array],
key: jnp.array,
spec_fn: Callable[..., jnp.ndarray],
key: jnp.ndarray,
lagrangian_form_per_layer: Iterable[LagrangianForm],
*input_bounds: Nest[graph_traversal.GraphInput],
) -> Tuple[Dict[int, DualOp], Params, ParamsTypes]:
@@ -352,7 +353,7 @@ def build_dual_fun(
affine_before_relu: bool,
spec_type: verify_utils.SpecType,
merge_problems: Optional[Dict[int, int]] = None,
) -> Callable[[Params, jnp.array, int], jnp.array]:
) -> Callable[[Params, jnp.ndarray, int], jnp.ndarray]:
"""Build the dual function that takes as input the inner/outer lagrangian parameters.
Args:
@@ -376,8 +377,9 @@
objective, and takes as input the inner and outer dual variables, and the
PRNG key.
"""
def dual_loss_fun(dual_params: Params,
key: jnp.array, step: int) -> jnp.array:
def dual_loss_fun(
dual_params: Params, key: jnp.ndarray, step: int
) -> jnp.ndarray:
lagrange_params = dual_params.outer
inner_vars_list = dual_params.inner

Expand All @@ -404,7 +406,7 @@ def dual_loss_fun(dual_params: Params,
loss += loss_inner_problem

stats['loss'] = loss
return loss, stats
return loss, stats # pytype: disable=bad-return-type # jnp-array

return dual_loss_fun

8 changes: 4 additions & 4 deletions jax_verify/extensions/functional_lagrangian/dual_solve.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,7 +44,7 @@ def solve_dual(
config: ConfigDict,
bounds: Sequence[sdp_utils.IntBound],
spec_type: verify_utils.SpecType,
spec_fn: Callable[..., jnp.array],
spec_fn: Callable[..., jnp.ndarray],
params: ModelParams,
dual_state: ConfigDict,
mode: str,
@@ -122,7 +122,7 @@ def solve_dual_train(
spec_type: verify_utils.SpecType,
dual_params_types: ParamsTypes,
logger: Callable[[int, Mapping[str, Any]], None],
key: jnp.array,
key: jnp.ndarray,
num_steps: int,
affine_before_relu: bool,
device_type=None,
@@ -224,7 +224,7 @@ def solve_dual_eval(
spec_type: verify_utils.SpecType,
dual_params_types: ParamsTypes,
logger: Callable[[int, Mapping[str, Any]], None],
key: jnp.array,
key: jnp.ndarray,
affine_before_relu: bool,
step: int,
merge_problems: Optional[Dict[int, int]] = None,
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 DeepMind Technologies Limited.
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.