API Reference

Verification methods

jax_verify.crown_bound_propagation(function, *bounds)

Performs CROWN as described in https://arxiv.org/abs/1811.00866.

Parameters
  • function – Function performing computation to obtain bounds for. Takes as only argument the network inputs.

  • *bounds – jax_verify.IntervalBound, bounds on the inputs of the function.

Returns

Bounds on the output of the function obtained by CROWN

Return type

output_bound
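
As a rough illustration of the kind of linear relaxation CROWN relies on, the sketch below bounds a single ReLU whose pre-activation lies in [l, u]. It is plain Python and does not call jax_verify. The function name crown_relu_relaxation is hypothetical, and the slope rule (lower slope 1 when u >= |l|, otherwise 0) is an assumption taken from the adaptive choice described in the CROWN paper, not from this library's source.

```python
def crown_relu_relaxation(l, u):
    """Return ((lower_slope, lower_icpt), (upper_slope, upper_icpt)) for relu on [l, u]."""
    if l >= 0.0:
        # Always active: relu(z) == z.
        return (1.0, 0.0), (1.0, 0.0)
    if u <= 0.0:
        # Always inactive: relu(z) == 0.
        return (0.0, 0.0), (0.0, 0.0)
    # Unstable neuron: the upper line is the chord from (l, 0) to (u, u).
    upper_slope = u / (u - l)
    upper_icpt = -l * u / (u - l)
    # Adaptive lower line through the origin (assumed rule, see the text above).
    lower_slope = 1.0 if u >= -l else 0.0
    return (lower_slope, 0.0), (upper_slope, upper_icpt)


def relu(z):
    return max(z, 0.0)


lower_line, upper_line = crown_relu_relaxation(-1.0, 3.0)
print(lower_line, upper_line)
```

Both lines hold for every z in [l, u], which is what allows them to replace the ReLU when linear bounds are propagated through a network.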

jax_verify.crownibp_bound_propagation(function, bounds)

Performs CROWN-IBP as described in https://arxiv.org/abs/1906.06316.

We first perform IBP to obtain intermediate bounds and then propagate linear bounds backwards.

Parameters
  • function – Function performing computation to obtain bounds for. Takes as only argument the network inputs.

  • bounds – jax_verify.IntervalBound, bounds on the inputs of the function.

Returns

Bounds on the output of the function obtained by CROWN-IBP

Return type

output_bound
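
The two stages described above (IBP for intermediate bounds, then backward propagation of linear bounds) can be sketched on a tiny one-hidden-layer network. This is a standalone illustration in plain Python, not the library's implementation: all function names are hypothetical, the network weights are made up, and the ReLU lower-slope rule is an assumption.

```python
def interval_affine(W, b, lower, upper):
    """IBP through y = W x + b for x in the box [lower, upper]."""
    lo, hi = [], []
    for row, bias in zip(W, b):
        lo.append(bias + sum(w * (l if w >= 0 else u)
                             for w, l, u in zip(row, lower, upper)))
        hi.append(bias + sum(w * (u if w >= 0 else l)
                             for w, l, u in zip(row, lower, upper)))
    return lo, hi


def relu_lines(l, u):
    """(lower_slope, lower_icpt, upper_slope, upper_icpt), valid on [l, u]."""
    if l >= 0.0:
        return 1.0, 0.0, 1.0, 0.0
    if u <= 0.0:
        return 0.0, 0.0, 0.0, 0.0
    lower_slope = 1.0 if u >= -l else 0.0  # assumed adaptive rule
    return lower_slope, 0.0, u / (u - l), -l * u / (u - l)


def crown_ibp_lower(W1, b1, w2, b2, lower, upper):
    """Lower bound on w2 . relu(W1 x + b1) + b2 over the input box."""
    # Stage 1: IBP gives bounds on the hidden pre-activations.
    z_lo, z_hi = interval_affine(W1, b1, lower, upper)
    # Stage 2: replace each ReLU by a line and substitute backwards.
    coeffs = [0.0] * len(lower)
    offset = b2
    for j, w in enumerate(w2):
        ls, lc, us, uc = relu_lines(z_lo[j], z_hi[j])
        # For a lower bound, a positive weight needs the lower line
        # and a negative weight needs the upper line.
        slope, icpt = (ls, lc) if w >= 0 else (us, uc)
        offset += w * (slope * b1[j] + icpt)
        for i in range(len(lower)):
            coeffs[i] += w * slope * W1[j][i]
    # Concretize the resulting linear function of x over the input box.
    return offset + sum(c * (l if c >= 0 else u)
                        for c, l, u in zip(coeffs, lower, upper))


W1 = [[1.0, 1.0], [1.0, -1.0]]
b1 = [0.0, 0.0]
w2 = [1.0, -1.0]
b2 = 0.0
bound = crown_ibp_lower(W1, b1, w2, b2, [-1.0, -1.0], [1.0, 1.0])
print(bound)
```

An upper bound follows by symmetry, swapping which line is used for positive and negative weights and concretizing with the opposite box corners.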

jax_verify.fastlin_bound_propagation(function, *bounds)

Performs FastLin as described in https://arxiv.org/abs/1804.09699.

Parameters
  • function – Function performing computation to obtain bounds for. Takes as only argument the network inputs.

  • *bounds – jax_verify.IntervalBound, bounds on the inputs of the function.

Returns

Bounds on the output of the function obtained by FastLin

Return type

output_bound
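
For intuition, the following standalone sketch shows the ReLU relaxation used in the FastLin paper, where the lower and upper lines share the same slope. It does not call jax_verify, and the function name fastlin_relu_relaxation is hypothetical.

```python
def fastlin_relu_relaxation(l, u):
    """Return (slope, lower_icpt, upper_icpt) such that, for z in [l, u],
    slope * z + lower_icpt <= relu(z) <= slope * z + upper_icpt."""
    if l >= 0.0:
        # Always active: relu(z) == z.
        return 1.0, 0.0, 0.0
    if u <= 0.0:
        # Always inactive: relu(z) == 0.
        return 0.0, 0.0, 0.0
    # Unstable neuron: two parallel lines with slope u / (u - l).
    slope = u / (u - l)
    return slope, 0.0, -l * slope


slope, lower_icpt, upper_icpt = fastlin_relu_relaxation(-2.0, 2.0)
print(slope, lower_icpt, upper_icpt)
```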

jax_verify.ibpfastlin_bound_propagation(function, *bounds)

Obtains the best of IBP and FastLin bounds.

Parameters
  • function – Function performing computation to obtain bounds for. Takes as only argument the network inputs.

  • *bounds – jax_verify.IntervalBound, bounds on the inputs of the function.

Returns

Bounds on the output of the function, taking the best of the IBP and FastLin bounds

Return type

output_bound
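
One natural reading of "best of" two sound bounds is an elementwise intersection: the larger of the two lower bounds and the smaller of the two upper bounds. The sketch below shows that combination in plain Python. The helper name intersect_bounds and the numbers are made up for illustration, and whether the library combines bounds exactly this way is an assumption.

```python
def intersect_bounds(bounds_a, bounds_b):
    """Combine two valid (lower, upper) bound pairs elementwise."""
    lower_a, upper_a = bounds_a
    lower_b, upper_b = bounds_b
    # If both pairs are sound, so are the tightest lower and upper bounds.
    lower = [max(a, b) for a, b in zip(lower_a, lower_b)]
    upper = [min(a, b) for a, b in zip(upper_a, upper_b)]
    return lower, upper


ibp_bounds = ([-2.0, 0.0], [3.0, 1.5])       # made-up numbers
fastlin_bounds = ([-1.0, -0.5], [3.5, 1.0])  # made-up numbers
best_lower, best_upper = intersect_bounds(ibp_bounds, fastlin_bounds)
print(best_lower, best_upper)
```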

jax_verify.interval_bound_propagation(function, *bounds)

Performs IBP as described in https://arxiv.org/abs/1810.12715.

Parameters
  • function – Function performing computation to obtain bounds for. Takes as only argument the network inputs.

  • *bounds – jax_verify.IntervalBound, bounds on the inputs of the function.

Returns

Bounds on the output of the function obtained by IBP

Return type

output_bound
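
To make the idea behind IBP concrete, here is a minimal standalone sketch that pushes an input box through one affine layer followed by a ReLU. It is plain Python and independent of jax_verify. The helper names and the weights are hypothetical.

```python
def affine_interval(weights, bias, lower, upper):
    """Bound y = W x + b for x in the box [lower, upper]."""
    out_lower, out_upper = [], []
    for row, b in zip(weights, bias):
        lo = hi = b
        for w, l, u in zip(row, lower, upper):
            # A positive weight maps the lower input to the lower output;
            # a negative weight swaps the roles.
            lo += w * l if w >= 0 else w * u
            hi += w * u if w >= 0 else w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def relu_interval(lower, upper):
    """ReLU is monotone, so it can be applied to each endpoint."""
    return [max(l, 0.0) for l in lower], [max(u, 0.0) for u in upper]


W = [[1.0, -1.0], [2.0, 0.5]]
b = [0.0, -1.0]
pre_lower, pre_upper = affine_interval(W, b, [-1.0, -1.0], [1.0, 1.0])
out_lower, out_upper = relu_interval(pre_lower, pre_upper)
print(out_lower, out_upper)
```

Repeating these two steps layer by layer yields bounds on the network output.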

jax_verify.solve_planet_relaxation(logits_fn, initial_bounds, boundprop_transform, objective, objective_bias, index, solver=<class 'jax_verify.src.cvxpy_relaxation_solver.CvxpySolver'>)

Solves the “Planet” (Ehlers 17) or “triangle” relaxation.

The general approach is to use jax_verify to generate constraints, which can then be passed to generic solvers. Note that using CVXPY incurs a large overhead when defining the LP: all constraints are defined element-wise to avoid representing convolutional layers as a single, inefficient matrix multiplication, and defining large numbers of constraints in CVXPY is slow.

Parameters
  • logits_fn – Mapping from inputs (batch_size x input_size) -> (batch_size x num_classes)

  • initial_bounds – IntervalBound with initial bounds on inputs, with lower and upper bounds of dimension (batch_size x input_size).

  • boundprop_transform – bound_propagation.BoundTransform instance, such as jax_verify.ibp_transform. Used to pre-compute interval bounds for intermediate activations used in defining the Planet relaxation.

  • objective – Objective to minimize, given as an array of coefficients to be applied to the output of logits_fn.

  • objective_bias – Bias to add to objective

  • index – Index in the batch for which to solve the relaxation

  • solver – A relaxation.RelaxationSolver, which specifies the backend to solve the resulting LP.

Returns

  • val – The optimal value from the relaxation.

  • status – The status of the relaxation solver.
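
As a small illustration of the "triangle" relaxation itself, the sketch below treats a single unstable ReLU y = relu(x) with x in [l, u]. The relaxation is the triangle defined by y >= 0, y >= x and y <= u (x - l) / (u - l), and a linear objective over it is minimized at one of its three vertices. This is a standalone sketch with hypothetical function names. It does not use jax_verify or CVXPY, and a real network would need an LP solver rather than vertex enumeration.

```python
def triangle_vertices(l, u):
    """Vertices of the triangle relaxation of y = relu(x) for l < 0 < u."""
    assert l < 0.0 < u
    return [(l, 0.0), (0.0, 0.0), (u, u)]


def in_triangle(x, y, l, u, tol=1e-9):
    """Check y >= 0, y >= x and y <= u * (x - l) / (u - l)."""
    return y >= -tol and y >= x - tol and y <= u * (x - l) / (u - l) + tol


def minimize_over_triangle(coef_x, coef_y, bias, l, u):
    """A linear objective over a polytope attains its minimum at a vertex."""
    return min(coef_x * x + coef_y * y + bias
               for x, y in triangle_vertices(l, u))


# Minimize -y + 0.5 over the relaxation of a ReLU with input in [-1, 2].
val = minimize_over_triangle(0.0, -1.0, 0.5, -1.0, 2.0)
print(val)
```

The point (0.5, 1.0) lies inside this triangle even though relu(0.5) = 0.5, which shows that the relaxation is a strict over-approximation of the ReLU graph.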

Bound objects

class jax_verify.LinearBound(lower_bound: jax_verify.src.fastlin.LinearExpression, upper_bound: jax_verify.src.fastlin.LinearExpression, reference: Optional[LinearBound])

Represent a pair of linear functions that encompass feasible activations.

We store the linear functions as LinearExpression objects in lower_lin and upper_lin, and also maintain a reference to the initial bounds on the input to be able to concretize the bounds when needed.
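
Concretizing means turning a linear function of the input into numeric bounds by evaluating it at the worst-case corner of the input box. The sketch below shows this in plain Python for a scalar output. The function concretize and the (coefficients, offset) tuples are hypothetical stand-ins and are not the library's LinearExpression class.

```python
def concretize(coeffs, offset, lower, upper):
    """Tightest interval for coeffs . x + offset over the box [lower, upper]."""
    lo = hi = offset
    for c, l, u in zip(coeffs, lower, upper):
        if c >= 0:
            lo += c * l
            hi += c * u
        else:
            # A negative coefficient reaches its minimum at the upper input.
            lo += c * u
            hi += c * l
    return lo, hi


# Hypothetical linear bounds: lower_lin(x) <= activation <= upper_lin(x).
lower_lin = ([1.0, -2.0], 0.5)
upper_lin = ([1.0, -1.0], 2.0)
box_lower, box_upper = [-1.0, -1.0], [1.0, 1.0]

# The concrete lower bound is the minimum of the lower function,
# and the concrete upper bound is the maximum of the upper function.
concrete_lower = concretize(*lower_lin, box_lower, box_upper)[0]
concrete_upper = concretize(*upper_lin, box_lower, box_upper)[1]
print(concrete_lower, concrete_upper)
```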

class jax_verify.IntervalBound(lower_bound: jax._src.numpy.lax_numpy.ndarray, upper_bound: jax._src.numpy.lax_numpy.ndarray)

Represent an interval where some activations might be valid.

Utility methods

jax_verify.open_file(name, *open_args, **open_kwargs)

Load file, downloading to /tmp/jax_verify first if necessary.