diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py
index 7367024..664b7ed 100644
--- a/src/jaxls/_solvers.py
+++ b/src/jaxls/_solvers.py
@@ -345,23 +345,12 @@ def step(
         if linear_state is not None:
             state_next.cg_state = linear_state
 
-        # Compute termination criteria.
-        state_next.termination_criteria, state_next.termination_deltas = (
-            self.termination._check_convergence(
-                state,
-                cost_updated=proposed_cost,
-                tangent=local_delta,
-                tangent_ordering=graph.tangent_ordering,
-                ATb=ATb,
-            )
-        )
-
         # Always accept Gauss-Newton steps.
         if self.trust_region is None:
             state_next.vals = vals
             state_next.residual_vector = proposed_residual_vector
             state_next.cost = proposed_cost
-
+            accept_flag = None
         # For Levenberg-Marquardt, we need to evaluate the step quality.
         else:
             step_quality = (proposed_cost - state.cost) / (
@@ -373,11 +362,6 @@ def step(
             )
             accept_flag = step_quality >= self.trust_region.step_quality_min
 
-            # Should not terminate if we're rejecting step.
-            state_next.termination_criteria = jnp.logical_and(
-                accept_flag, state_next.termination_criteria
-            )
-
             state_next.vals = jax.tree_map(
                 lambda proposed, current: jnp.where(accept_flag, proposed, current),
                 vals,
@@ -401,6 +385,18 @@ def step(
                 ),
             )
 
+        # Compute termination criteria.
+        state_next.termination_criteria, state_next.termination_deltas = (
+            self.termination._check_convergence(
+                state,
+                cost_updated=proposed_cost,
+                tangent=local_delta,
+                tangent_ordering=graph.tangent_ordering,
+                ATb=ATb,
+                accept_flag=accept_flag,
+            )
+        )
+
         state_next.iterations += 1
         return state_next
 
@@ -441,12 +437,14 @@ def _check_convergence(
         tangent: jax.Array,
         tangent_ordering: VarTypeOrdering,
         ATb: jax.Array,
+        accept_flag: jax.Array | None = None,
     ) -> tuple[jax.Array, jax.Array]:
         """Check for convergence!"""
 
         # Cost tolerance
-        cost_delta = jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
-        converged_cost = cost_delta < self.cost_tolerance
+        cost_absdelta = jnp.abs(cost_updated - state_prev.cost)
+        cost_reldelta = cost_absdelta / state_prev.cost
+        converged_cost = cost_reldelta < self.cost_tolerance
 
         # Gradient tolerance
         flat_vals = jax.flatten_util.ravel_pytree(state_prev.vals)[0]
@@ -468,11 +466,24 @@ def _check_convergence(
         )
         converged_parameters = param_delta < self.parameter_tolerance
 
-        return jnp.array(
+        # Check termination flags. We'll terminate if any of the conditions are met.
+        term_flags = jnp.array(
             [
                 converged_cost,
                 converged_gradient,
                 converged_parameters,
                 state_prev.iterations >= (self.max_iterations - 1),
             ]
-        ), jnp.array([cost_delta, gradient_mag, param_delta])
+        )
+
+        # Only consider the first three conditions if steps are accepted.
+        if accept_flag is not None:
+            term_flags = term_flags.at[:3].set(
+                jnp.logical_and(
+                    term_flags[:3],
+                    # We ignore accept_flag if the cost _actually_ didn't change at all.
+                    jnp.logical_or(accept_flag, cost_absdelta == 0.0),
+                )
+            )
+
+        return term_flags, jnp.array([cost_reldelta, gradient_mag, param_delta])
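
The behavioral change here: `_check_convergence` now receives the Levenberg-Marquardt `accept_flag` and masks the cost/gradient/parameter convergence flags itself, rather than `step()` patching `termination_criteria` after the fact. The max-iteration flag is deliberately left unmasked, so a long run of rejected steps can still hit the iteration cap. Below is a minimal standalone sketch of that masking logic; the function name `mask_termination_flags` and the test values are hypothetical, for illustration only, and mirror the diff rather than the library's actual API.

```python
import jax
import jax.numpy as jnp


def mask_termination_flags(
    term_flags: jax.Array,  # [cost, gradient, parameters, max-iterations] converged?
    accept_flag: jax.Array | None,
    cost_absdelta: jax.Array,
) -> jax.Array:
    """Suppress convergence flags computed from a rejected step.

    Index 3 (max iterations) is left untouched, so the solver can still
    exit after a run of rejected Levenberg-Marquardt steps.
    """
    if accept_flag is None:  # Gauss-Newton: steps are always accepted.
        return term_flags
    return term_flags.at[:3].set(
        jnp.logical_and(
            term_flags[:3],
            # If the cost didn't change at all, the solve has stalled;
            # allow termination even though the step was rejected.
            jnp.logical_or(accept_flag, cost_absdelta == 0.0),
        )
    )


# A rejected step whose cost moved slightly must not report convergence:
flags = jnp.array([True, False, True, False])
print(
    mask_termination_flags(
        flags, accept_flag=jnp.array(False), cost_absdelta=jnp.array(1e-9)
    )
)  # [False False False False]
```

The `cost_absdelta == 0.0` escape hatch matters because a tiny `cost_reldelta` can satisfy `cost_tolerance` even on a rejected step; without the mask, the solver could report cost convergence for an update it never applied, while the escape hatch lets a perfectly stationary solve terminate instead of spinning on rejected steps until `max_iterations`.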