Skip to content

Commit

Permalink
Fix/simplify tutorial.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Aug 24, 2024
1 parent 0e488b9 commit c387209
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tutorials/StochasticGradientDescent.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,32 @@ p_opt2 = stochastic_gradient_descent(M, gradf, p0)

This result is reasonably close. But we can improve it by using a `DirectionUpdateRule`, namely:

On the one hand [`MomentumGradient`](https://manoptjl.org/stable/solvers/gradient_descent.html#Manopt.MomentumGradient), which requires both the manifold and the initial value, to keep track of the iterate and parallel transport the last direction to the current iterate.
On the one hand [`MomentumGradient`](@ref), which requires both the manifold and the initial value, to keep track of the iterate and parallel transport the last direction to the current iterate.
The necessary `vector_transport_method` keyword is set to a suitable default on every manifold,
see ``[`default_vector_transport_method`](@extref `ManifoldsBase.default_vector_transport_method-Tuple{AbstractManifold}`)``{=commonmark}. We get
"""

```{julia}
p_opt3 = stochastic_gradient_descent(
M, gradf, p0; direction=MomentumGradient(M, p0; direction=StochasticGradient(M))
M, gradf, p0; direction=MomentumGradient(; direction=StochasticGradient())
)
```

```{julia}
MG = MomentumGradient(M, p0; direction=StochasticGradient(M));
@benchmark stochastic_gradient_descent($M, $gradf, $p0; direction=$MG)
MG = MomentumGradient(; direction=StochasticGradient());
@benchmark stochastic_gradient_descent($M, $gradf, p=$p0; direction=$MG)
```

And on the other hand the [`AverageGradient`](https://manoptjl.org/stable/solvers/gradient_descent.html#Manopt.AverageGradient) computes an average of the last `n` gradients. This is done by

```{julia}
p_opt4 = stochastic_gradient_descent(
M, gradf, p0; direction=AverageGradient(M, p0; n=10, direction=StochasticGradient(M)), debug=[],
M, gradf, p0; direction=AverageGradient(; n=10, direction=StochasticGradient()), debug=[],
)
```
```{julia}
AG = AverageGradient(M, p0; n=10, direction=StochasticGradient(M));
@benchmark stochastic_gradient_descent($M, $gradf, $p0; direction=$AG, debug=[])
AG = AverageGradient(; n=10, direction=StochasticGradient(M));
@benchmark stochastic_gradient_descent($M, $gradf, p=$p0; direction=$AG, debug=[])
```

Note that the default `StoppingCriterion` is a fixed number of iterations which helps the comparison here.
Expand Down

0 comments on commit c387209

Please sign in to comment.