
Iterative NUTS


Du Phan, Neeraj Pradhan


In this note, we would like to present how we convert the recursive nature of NUTS into an iterative scheme, which integrates nicely with JAX's compilation mechanism. The two main references we use for the NUTS sampler are:

  1. Matthew D. Hoffman, Andrew Gelman. The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. arXiv.
  2. Michael Betancourt. A Conceptual Introduction to Hamiltonian Monte Carlo. arXiv.

Recursive NUTS

While for HMC we know the number of leapfrog (Verlet) steps at the beginning of each trajectory, with NUTS we don't know in advance how far (i.e. how many leapfrog steps) we will go. We keep moving as far as we can until a turning condition happens (there are in fact two more stopping conditions: we have gone too far, or the energy is diverging; we skip the discussion of those because they are easy to check). While moving, we construct a binary tree which keeps track of all the information we need: information about the leftmost leaf, the rightmost leaf, and a proposal for the next sampling step. The following figure from reference [1] nicely illustrates this process.

[figure: binary-tree (from reference [1])]

The algorithm works as follows:

  • We start with a single node, called a basetree.
  • Keep doubling the tree until turning happens:
    • Choose a direction: forward (or right) or backward (or left).
    • Build a subtree in that direction, with the same depth d as the current tree, by recursion (a code sketch of this recursion is given right after this list); the base case of the recursion is to build a single node (a subtree with depth 0):
      1. Recursively build the first half (with depth d - 1) of the subtree.
      2. If the first half is turning, we stop. Otherwise, recursively build the second half, also with depth d - 1.
      3. This subtree is turning if the first half is turning, or the second half is turning, or the combined tree is turning.
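
To make the recursion concrete, here is a minimal, self-contained Python sketch of the doubling step. This is not the NumPyro implementation: we use toy 1-D harmonic dynamics so that the code runs, and we omit the sampling proposal and the divergence/maximum-depth checks mentioned above.

```python
from collections import namedtuple

# A subtree records its leftmost and rightmost (position, momentum) pairs
# and whether it satisfies the turning condition. The real NUTS tree also
# carries a proposal for the next sample, which we omit here.
Tree = namedtuple("Tree", ["z_left", "r_left", "z_right", "r_right", "turning"])

STEP_SIZE = 0.1


def leapfrog(z, r, direction):
    # one velocity-Verlet step for the toy potential U(z) = z**2 / 2
    eps = direction * STEP_SIZE
    r = r - 0.5 * eps * z
    z = z + eps * r
    r = r - 0.5 * eps * z
    return z, r


def is_turning(tree):
    # 1-D analogue of the U-turn condition
    dz = tree.z_right - tree.z_left
    return dz * tree.r_left < 0 or dz * tree.r_right < 0


def combine(first, second, direction):
    # merge two half trees; the second half extends the first one
    if direction > 0:
        return Tree(first.z_left, first.r_left, second.z_right, second.r_right, False)
    return Tree(second.z_left, second.r_left, first.z_right, first.r_right, False)


def build_tree(z, r, direction, depth):
    if depth == 0:
        # base case (basetree): a single node obtained by one leapfrog step
        z, r = leapfrog(z, r, direction)
        return Tree(z, r, z, r, False)
    # 1. recursively build the first half of the subtree
    first = build_tree(z, r, direction, depth - 1)
    if first.turning:
        return first
    # 2. recursively build the second half, starting from the outer edge
    #    of the first half
    z_edge, r_edge = ((first.z_right, first.r_right) if direction > 0
                      else (first.z_left, first.r_left))
    second = build_tree(z_edge, r_edge, direction, depth - 1)
    # 3. turning if either half turns or the combined tree turns
    tree = combine(first, second, direction)
    return tree._replace(turning=first.turning or second.turning or is_turning(tree))
```

For example, `build_tree(0.0, 1.0, 1, 3)` builds a depth-3 subtree, i.e. 8 leapfrog steps, to the right of the state (z, r) = (0, 1).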

By using recursion together with a neat transition kernel from the NUTS authors (which governs the probability of accepting a proposal from the build process; see section A.3 of reference [2]), the memory requirement is O(d), where d is the depth of the subtree. This is very efficient given that building that subtree takes 2^d leapfrog steps.

Iterative NUTS

In this section, we will illustrate how to build a subtree of depth 4 iteratively (instead of using recursion as above). This means we want to construct 16 nodes (numbered 0 -> 15) for the tree. Instead of recursively building subtrees 0 -> 7 and 8 -> 15, we iteratively go straight from 0 to 15, but stop when a turning condition happens. The trickiest parts are:

  1. Deciding on a stopping condition that is equivalent to the recursive algorithm,
  2. Maintaining the memory efficiency of the recursion.

For example, at node 3, we need to check the turning conditions of the trees 0 -> 3 and 2 -> 3. At node 7, we check the turning conditions of the trees 0 -> 7, 4 -> 7, and 6 -> 7. At node 12, we don't need to check any turning condition, but at node 13, we need to check the turning condition of the tree 12 -> 13. It is helpful to draw the binary tree and trace the process.

[figure: subtree]

The first case requires storing information at nodes 0 and 2. The second case requires storing information at nodes 0, 4, and 6. The fourth case requires storing information at node 12. The number of nodes we need to store changes dynamically as we move, but its maximum is 4, which is the depth of the subtree we need to build. This maximum is attained at node 15, where we need to check the turning conditions of the trees 0 -> 15, 8 -> 15, 12 -> 15, and 14 -> 15.

First, we create a storage array R of size 4 to store this information. Then the whole process works as follows:

  • Step 0: R[0] = node_0
  • Step 1: check turning condition of R[0] (or node_0) and node_1
  • Step 2: R[1] = node_2
  • Step 3: check turning condition of R[0] and node_3, R[1] and node_3
  • Step 4: R[1] = node_4 (we update index 1 because node_2 is no longer needed for the rest of the process)
  • Step 5: check turning condition of R[1] and node_5 (though it would be reasonable, we don't check the turning condition of R[0] and node_5; the reason is to make the scheme equivalent to the recursive algorithm)
  • Step 6: R[2] = node_6
  • Step 7: check turning condition of R[0] and node_7, R[1] and node_7, R[2] and node_7
  • Step 8: R[1] = node_8
  • Step 9: check turning condition of R[1] and node_9
  • Step 10: R[2] = node_10
  • Step 11: check turning condition of R[1] and node_11, R[2] and node_11
  • Step 12: R[2] = node_12
  • Step 13: check turning condition of R[2] and node_13
  • Step 14: R[3] = node_14
  • Step 15: check turning condition of R[0] and node_15, R[1] and node_15, R[2] and node_15, R[3] and node_15

In summary, at even steps we update the storage, and at odd steps we verify the turning conditions. If a turning condition is met, we stop; otherwise, we go to the next step.
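
The following is a minimal sketch of this even/odd loop for a subtree of depth 4. Here `is_turning_between` stands in for the real U-turn check between a stored node and the current node, and the two helpers `update_index` and `check_indices` (our names) correspond exactly to the two technical points discussed next; sketches of them are given below.

```python
def build_subtree_iterative(nodes, is_turning_between, update_index, check_indices):
    # `nodes` yields the 16 nodes 0, 1, ..., 15 of a depth-4 subtree in order;
    # returns the node index at which we stop (15 if no turning happens).
    depth = 4
    R = [None] * depth                      # storage for at most `depth` nodes
    for n, node in enumerate(nodes):
        if n % 2 == 0:
            # even step: store this node for later turning checks
            R[update_index(n)] = node
        else:
            # odd step: check the current node against a slice of the storage
            idx_min, idx_max = check_indices(n)
            if any(is_turning_between(R[i], node) for i in range(idx_min, idx_max + 1)):
                return n
    return len(nodes) - 1
```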

Now, we'll discuss two technical points of the iterative scheme:

  • At even steps, which index of the memory do we need to update?
  • At odd steps, which portion of the memory do we need to check for turning conditions against the current node?

Though it seems a bit tricky, things are easier to see when we look at the binary representation of the node indices. The binary representation of node 7 is 111, which corresponds to the right leaves of trees of depths 1, 2, 3. The binary representation of node 12 is 1100, which corresponds to the left leaf of a tree of depth 1, the left leaf of a tree of depth 2, the right leaf of a tree of depth 3, and the right leaf of a tree of depth 4. In summary, a 1 at position i (counted from right to left in the binary representation) corresponds to the right leaf of a tree of depth i; a 0 at position i corresponds to the left leaf.

Even step: the mapping from node index (in binary representation) to memory index is: 0=0 -> 0, 2=10 -> 1, 4=100 -> 1, 6=110 -> 2, 8=1000 -> 1, 10=1010 -> 2, 12=1100 -> 2, 14=1110 -> 3. We can see that the memory index we need to update is the number of 1 bits in the binary form of the node index. So we can use a bit-count algorithm to compute the memory index: R_idx = bitcount(node_idx).
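
In plain Python, this even-step rule can be sketched as follows (the helper name `update_index` is ours, matching the loop sketch above):

```python
def update_index(node_idx):
    # memory slot to overwrite at an even step: the number of 1 bits
    # in the binary representation of the node index
    return bin(node_idx).count("1")

# reproduces the mapping above: 0 -> 0, 2 -> 1, 4 -> 1, 6 -> 2, ...
assert [update_index(n) for n in range(0, 16, 2)] == [0, 1, 1, 2, 1, 2, 2, 3]
```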

Odd step: we need to determine idx_min and idx_max of the memory, so that we check the turning conditions for all memory indices from idx_min to idx_max against the current node.

  • idx_max is the memory index which we updated in the previous even step. For example, if we are at node 9=1001, then idx_max = bitcount(1001 - 1) = bitcount(1000) = 1. In other words, we count the number of 1 bits, excluding the last one, in the binary form.
  • Instead of calculating idx_min directly, we calculate the number of memory indices we need to check turning conditions for, i.e. num_idxs = idx_max - idx_min + 1. The mapping from node index (in binary representation) to num_idxs is: 1=1 -> 1, 3=11 -> 2, 5=101 -> 1, 7=111 -> 3, 9=1001 -> 1, 11=1011 -> 2, 13=1101 -> 1, 15=1111 -> 4. We can see that num_idxs is the number of contiguous trailing 1 bits in the binary form (which correspond to the contiguous right leaves of subtrees in the whole binary tree).
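
A matching sketch of the odd-step rule, again using plain Python bit manipulation (the helper name `check_indices` is ours, matching the loop sketch above):

```python
def check_indices(node_idx):
    # idx_max: number of 1 bits excluding the last one, i.e. bitcount(node_idx - 1)
    idx_max = bin(node_idx - 1).count("1")
    # num_idxs: number of contiguous trailing 1 bits of node_idx
    b = bin(node_idx)
    num_idxs = len(b) - len(b.rstrip("1"))
    idx_min = idx_max - num_idxs + 1
    return idx_min, idx_max

# node 15 = 1111 checks memory slots 0..3; node 9 = 1001 checks only slot 1
assert check_indices(15) == (0, 3)
assert check_indices(9) == (1, 1)
```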

Conclusion

By resolving the above two technical points, we have succeeded in converting the recursive nature of NUTS into an iterative one. The iterative scheme allows more control over the process (e.g. it might allow adding more stopping conditions), but that's for future research. The biggest advantage right now is that it allows us to use JAX to compile the computation of the whole trajectory, so the overhead cost is hugely reduced. We also want to note that recursion has been one of the technical issues with implementing NUTS in graph (vs. eager) mode, as mentioned in the Simple, Distributed, and Accelerated Probabilistic Programming paper. We hope that our solution will help HMC researchers investigate more cool features of NUTS.
