API Reference

Verification methods

jax_verify.crownibp_bound_propagation(function: Callable[[...], Union[jax._src.numpy.lax_numpy.ndarray, Sequence[jax._src.numpy.lax_numpy.ndarray], Dict[Any, jax._src.numpy.lax_numpy.ndarray]]], *bounds: Union[jax_verify.src.graph_traversal.InputBound, jax_verify.src.graph_traversal.JittableInputBound, jax._src.numpy.lax_numpy.ndarray, Sequence[Union[jax_verify.src.graph_traversal.InputBound, jax_verify.src.graph_traversal.JittableInputBound, jax._src.numpy.lax_numpy.ndarray]], Dict[Any, Union[jax_verify.src.graph_traversal.InputBound, jax_verify.src.graph_traversal.JittableInputBound, jax._src.numpy.lax_numpy.ndarray]]]) -> Union[jax_verify.src.bound_propagation.Bound, jax._src.numpy.lax_numpy.ndarray, Sequence[Union[jax_verify.src.bound_propagation.Bound, jax._src.numpy.lax_numpy.ndarray]], Dict[Any, Union[jax_verify.src.bound_propagation.Bound, jax._src.numpy.lax_numpy.ndarray]]]

Performs Crown-IBP as described in https://arxiv.org/abs/1906.06316.

We first perform IBP to obtain intermediate bounds and then propagate linear bounds backwards.

Parameters
  • function – Function whose outputs are to be bounded. It takes the network inputs as its only argument.

  • *bounds – jax_verify.IntervalBound objects giving bounds on the inputs of the function.

Returns
  • output_bounds – Bounds on the outputs of the function obtained by Crown-IBP.
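The two stages described above (IBP for intermediate bounds, then a backward pass with linear bounds) can be sketched in plain Python for a tiny network. This is an illustrative sketch, not the library's implementation: it assumes a single hidden ReLU layer with a scalar output, computes only the lower bound, and every name in it (affine_interval, crown_ibp_lower, the weights) is hypothetical.

```python
def affine_interval(W, b, lower, upper):
    """Interval bounds on W x + b for x in the box [lower, upper]."""
    out_lower, out_upper = [], []
    for row, bias in zip(W, b):
        lo = hi = bias
        for w, l, u in zip(row, lower, upper):
            # A negative weight swaps the roles of the two endpoints.
            lo += w * l if w >= 0 else w * u
            hi += w * u if w >= 0 else w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def crown_ibp_lower(W1, b1, w2, b2, x_lower, x_upper):
    """Lower bound on w2 . relu(W1 x + b1) + b2 over the input box."""
    # Stage 1 (IBP): interval bounds on the pre-activations z = W1 x + b1.
    z_lower, z_upper = affine_interval(W1, b1, x_lower, x_upper)

    # Stage 2 (backward pass): bound each ReLU by a line slope * z + offset,
    # then substitute z = W1 x + b1 to get a linear lower bound in x.
    coeffs = [0.0] * len(x_lower)
    const = b2
    for row, bias, c, l, u in zip(W1, b1, w2, z_lower, z_upper):
        if l >= 0:        # always active: relu(z) = z
            slope, offset = 1.0, 0.0
        elif u <= 0:      # always inactive: relu(z) = 0
            slope, offset = 0.0, 0.0
        elif c >= 0:      # need a lower line: relu(z) >= slope * z
            slope, offset = (1.0 if u >= -l else 0.0), 0.0
        else:             # need an upper line: relu(z) <= u / (u - l) * (z - l)
            slope = u / (u - l)
            offset = -slope * l
        const += c * (slope * bias + offset)
        for i, w in enumerate(row):
            coeffs[i] += c * slope * w

    # Stage 3: minimise the linear bound over the input box.
    return const + sum(
        a * (lo if a >= 0 else hi) for a, lo, hi in zip(coeffs, x_lower, x_upper)
    )


# Hypothetical network: both hidden units see x1 + x2, with different biases.
W1, b1 = [[1.0, 1.0], [1.0, 1.0]], [3.0, 0.0]
w2, b2 = [1.0, -1.0], 0.0
x_lower, x_upper = [-1.0, -1.0], [1.0, 1.0]

crown_lower = crown_ibp_lower(W1, b1, w2, b2, x_lower, x_upper)

# Plain IBP on the same network, for comparison.
z_lower, z_upper = affine_interval(W1, b1, x_lower, x_upper)
ibp_lower = b2 + sum(
    c * (max(0.0, l) if c >= 0 else max(0.0, u))
    for c, l, u in zip(w2, z_lower, z_upper)
)
print(crown_lower, ibp_lower)
```

In this particular example the backward pass keeps track of the fact that both hidden units depend on the same input combination, which plain IBP discards, so the linear bound comes out tighter. That is a property of this example and not a general guarantee.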

jax_verify.interval_bound_propagation(function, *bounds)

Performs IBP as described in https://arxiv.org/abs/1810.12715.

Parameters
  • function – Function whose outputs are to be bounded. It takes the network inputs as its only argument.

  • *bounds – jax_verify.IntervalBound objects giving bounds on the inputs of the function.

Returns
  • output_bound – Bounds on the output of the function obtained by IBP.
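To make the computation concrete, here is a minimal plain-Python sketch of interval bound propagation through an affine layer followed by a ReLU. It illustrates the technique only and does not call jax_verify; the helper names and the small network are hypothetical.

```python
def affine_interval(W, b, lower, upper):
    """Propagate the box [lower, upper] through x -> W x + b."""
    out_lower, out_upper = [], []
    for row, bias in zip(W, b):
        lo = hi = bias
        for w, l, u in zip(row, lower, upper):
            # A positive weight maps the lower input to the lower output;
            # a negative weight swaps the roles of the two endpoints.
            lo += w * l if w >= 0 else w * u
            hi += w * u if w >= 0 else w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def relu_interval(lower, upper):
    """ReLU is monotone, so it can be applied to each endpoint."""
    return [max(0.0, l) for l in lower], [max(0.0, u) for u in upper]


# Hypothetical network: 2 inputs -> 2 hidden ReLU units -> 1 output.
W1, b1 = [[1.0, -1.0], [2.0, 1.0]], [0.0, -1.0]
W2, b2 = [[1.0, -1.0]], [0.5]

# Input box: each coordinate lies in [-1, 1].
x_lower, x_upper = [-1.0, -1.0], [1.0, 1.0]

h_lower, h_upper = relu_interval(*affine_interval(W1, b1, x_lower, x_upper))
out_lower, out_upper = affine_interval(W2, b2, h_lower, h_upper)
print(out_lower, out_upper)
```

Each layer is bounded independently of the others, which keeps the cost close to that of two forward passes but ignores correlations between units.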

jax_verify.solve_planet_relaxation(logits_fn, initial_bounds, boundprop_transform, objective, objective_bias, index, solver=<class 'jax_verify.src.mip_solver.cvxpy_relaxation_solver.CvxpySolver'>)

Solves the “Planet” (Ehlers 17) or “triangle” relaxation.

The general approach is to use jax_verify to generate constraints, which can then be passed to generic solvers. Note that using CVXPY incurs a large overhead when defining the LP: all constraints are defined element-wise, to avoid representing convolutional layers as a single, inefficient matrix multiplication, and CVXPY is slow at defining large numbers of constraints.

Parameters
  • logits_fn – Mapping from inputs of shape (batch_size, input_size) to logits of shape (batch_size, num_classes).

  • initial_bounds – IntervalBound with initial bounds on inputs, with lower and upper bounds of dimension (batch_size, input_size).

  • boundprop_transform – bound_propagation.BoundTransform instance, such as jax_verify.ibp_transform. Used to pre-compute interval bounds for intermediate activations used in defining the Planet relaxation.

  • objective – Objective to minimize, given as an array of coefficients applied to the output of logits_fn.

  • objective_bias – Bias to add to the objective.

  • index – Index in the batch for which to solve the relaxation.

  • solver – A relaxation.RelaxationSolver, which specifies the backend to solve the resulting LP.

Returns
  • val – The optimal value from the relaxation.

  • solution – The optimal solution found by the solver.

  • status – The status of the relaxation solver.
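For intuition about what the generated constraints look like, the sketch below writes out the triangle relaxation of a single unstable ReLU y = relu(x) with pre-activation bounds l < 0 < u: y >= 0, y >= x, and y <= u (x - l) / (u - l). It is plain Python, independent of jax_verify and CVXPY, and all function names are hypothetical. Because the feasible set is a triangle, the tiny LP is solved here by checking its three vertices instead of calling a solver.

```python
def triangle_vertices(l, u):
    """Vertices of the convex hull of {(x, relu(x)) : l <= x <= u}."""
    assert l < 0 < u, "the relaxation is only needed for unstable units"
    return [(l, 0.0), (0.0, 0.0), (u, u)]


def in_triangle(x, y, l, u, tol=1e-9):
    """Check the three linear constraints of the relaxation."""
    return (
        y >= -tol                                # y >= 0
        and y >= x - tol                         # y >= x
        and y <= u * (x - l) / (u - l) + tol     # chord from (l, 0) to (u, u)
    )


def minimise_over_triangle(cx, cy, l, u):
    """A linear objective over a polytope attains its minimum at a vertex."""
    return min(cx * x + cy * y for x, y in triangle_vertices(l, u))


l, u = -1.0, 2.0
best = minimise_over_triangle(1.0, -1.0, l, u)  # minimise x - y
print(best)
```

For a single unit all three vertices lie on the ReLU graph, so the relaxed optimum matches the exact one. The relaxation becomes loose only when several units share inputs and their triangles are combined in one LP.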

Bound objects

class jax_verify.IntervalBound(lower_bound: jax._src.numpy.lax_numpy.ndarray, upper_bound: jax._src.numpy.lax_numpy.ndarray)

Represents an interval, given by a lower and an upper bound, within which activations may lie.

Utility methods

jax_verify.open_file(name, *open_args, **open_kwargs)

Load file, downloading to /tmp/jax_verify first if necessary.
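The download-if-missing behaviour can be pictured with the generic caching pattern below. This is only a sketch of the pattern, not the library's code: open_cached, its fetch callback, and the cache directory argument are assumptions made for illustration, and the example uses a fake fetch function so that it needs no network access.

```python
import os
import tempfile


def open_cached(name, fetch, cache_dir, *open_args, **open_kwargs):
    """Open cache_dir/name, calling fetch(name, path) first if it is missing."""
    path = os.path.join(cache_dir, name)
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fetch(name, path)  # in a real setting this would download the file
    return open(path, *open_args, **open_kwargs)


calls = []


def fake_fetch(name, path):
    """Stand-in for a download: record the call and write dummy bytes."""
    calls.append(name)
    with open(path, "wb") as f:
        f.write(b"weights")


with tempfile.TemporaryDirectory() as cache_dir:
    for _ in range(2):
        # The second iteration finds the cached copy and skips the fetch.
        with open_cached("models/demo.bin", fake_fetch, cache_dir, "rb") as f:
            data = f.read()

print(calls, data)
```

The extra positional and keyword arguments are forwarded unchanged to the built-in open, mirroring the *open_args and **open_kwargs in the signature above.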