Question on the Complexity of CKY #99

Open

allanj opened this issue Mar 8, 2021 · 8 comments

Comments

@allanj

allanj commented Mar 8, 2021

  1. By using the GPU, we are able to reduce the complexity of linear-chain CRF from O(NT^2) to O(log N), where N is the sentence length and T is the number of labels.

So, if I view the linear-chain CRF as a specific case of a tree whose height H equals the sequence length N, the complexity can be rewritten as O(log H).

  2. Then, in the case of CKY, I can see that the complexity can be reduced to O(N), i.e., O(H), by parallel computing. I'm wondering if it can be further reduced to O(log H) as well, using the parallel scanning algorithm mentioned in the tutorial?
@srush
Collaborator

srush commented Mar 8, 2021

So, just to unify terminology:

The way we compute linear-chain CRF in O(log N) time is by viewing it as a single balanced binary tree of height log N. Each layer of this tree is computed in parallel for a total of O(log N) sequential operations. (Each sequential operation is quite expensive in the process).
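(A minimal sketch of that schedule, in plain PyTorch rather than the library's semiring machinery, assuming the unary and boundary scores are already folded into the N-1 pairwise log-potential matrices:)

```python
import torch

def logmatmul(A, B):
    # Matrix "product" in the log semiring:
    # C[..., i, j] = logsumexp_k (A[..., i, k] + B[..., k, j]).
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)

def logpartition_scan(potentials):
    # potentials: (N-1) x T x T log-potential matrices of a linear-chain CRF.
    # Each round halves the number of matrices, so there are O(log N)
    # sequential steps; all the work within a round runs in parallel.
    while potentials.shape[0] > 1:
        if potentials.shape[0] % 2 == 1:  # carry the odd matrix to the next round
            combined = logmatmul(potentials[0:-1:2], potentials[1::2])
            potentials = torch.cat([combined, potentials[-1:]], dim=0)
        else:
            potentials = logmatmul(potentials[0::2], potentials[1::2])
    # Sum over the initial and final tags for the log-partition function.
    return torch.logsumexp(potentials[0], dim=(0, 1))
```

Swapping logsumexp for max gives the Viterbi score under the same O(log N) schedule.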

For CKY, we need to consider at least trees of height O(N), since we may have fully right-branching trees. As far as I know, there is no way to do better than O(N) serial operations, one for each layer (i.e., width of span).

The other nasty thing about CKY is that the shape of the operations changes drastically as you go up the chart. At the bottom layer there are N spans, each with 1 possible child split, whereas at the top layer there is 1 span with N possible child splits. This is an extremely non-GPU-friendly pattern. (I still haven't figured out the ideal way to compute it.)
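(To make the shape problem concrete, here is a toy inside pass with a single nonterminal; my own sketch, not code from this repo. The outer loop over widths is inherently serial, and the work per span changes at every layer:)

```python
import torch

def cky_inside_toy(span_score):
    # span_score: N x N tensor; span_score[i, j] is the log-score of span (i, j).
    # Inside algorithm in the log semiring with a single nonterminal.
    N = span_score.shape[0]
    beta = torch.full((N, N), float("-inf"))
    beta.diagonal().copy_(span_score.diagonal())  # width-1 spans
    for w in range(1, N):  # O(N) serial layers, one per span width
        for start in range(N - w):  # a GPU version would batch this loop
            end = start + w
            # Combine all w split points: (start, k) with (k + 1, end).
            splits = beta[start, start:end] + beta[start + 1:end + 1, end]
            beta[start, end] = span_score[start, end] + torch.logsumexp(splits, 0)
    return beta[0, N - 1]
```

Layer w has N - w spans, each summing over w splits, so the tensor shapes are different at every layer, which is what makes a clean batched kernel hard.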

@allanj
Author

allanj commented Mar 9, 2021

Thanks for the clarification. Now I understand the situation.

One more question regarding the decoding procedure for Linear-chain CRF.

  1. I figured out how to do argmax/Viterbi in O(log N) in the linear-chain CRF.
  2. But it seems that when we want to recover the state sequence, we still need O(N) back-tracking.
  3. What I found in this code repo seems to calculate the edge marginals, but not the global Viterbi path:

```python
def marginals(self, logpotentials, lengths=None, _raw=False):
    """
    Compute the marginals of a structured model.

    Parameters:
        logpotentials : generic params (see class)
        lengths : None or b long tensor mask

    Returns:
        marginals : b x (N-1) x C x C table
    """
    v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
    if _raw:
        all_m = []
        for k in range(v.shape[0]):
            obj = v[k].sum(dim=0)
            marg = torch.autograd.grad(
                obj,
                edges,
                create_graph=True,
                only_inputs=True,
                allow_unused=False,
            )
            all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
        return torch.stack(all_m, dim=0)
    else:
        obj = self.semiring.unconvert(v).sum(dim=0)
        marg = torch.autograd.grad(
            obj, edges, create_graph=True, only_inputs=True, allow_unused=False
        )
        a_m = self._arrange_marginals(marg)
        return self.semiring.unconvert(a_m)
```

Not sure if I'm correct here.

My question: is it true that we can't avoid the O(N) back-tracking?

For example, we use an O(log N) algorithm to obtain a matrix of size (batch_size, sequence_length, label_size) that records the best indices leading from the start tag to each current label.
But after this, it seems I still need O(N) for the back-tracking.

@srush
Collaborator

srush commented Mar 9, 2021

You can do O(log N) [parallel] backtracking. The trick in the code is that we never implement the backward/backpointer step. We rely on the fact that in PyTorch the (sub)gradient used for the max operator is the 1-hot argmax vector. Therefore, if you compute the max score, calling .backward() will give you the argmax/Viterbi sequence. If you look at the code here (https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/distributions.py#L123) you will see that is exactly what it does.

(This trick is really cool, and should be documented much better in the codebase.)
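(A minimal illustration of the trick, outside the library; the O(N) forward loop here is just for readability, and the same idea applies to the O(log N) scan. No backpointer table is stored, and .backward() recovers the path:)

```python
import torch

def viterbi_via_autograd(log_potentials):
    # log_potentials: (N-1) x C x C edge scores of a linear-chain CRF.
    edges = log_potentials.detach().requires_grad_(True)
    alpha = edges[0].max(dim=0).values  # best score ending in each tag at step 1
    for t in range(1, edges.shape[0]):
        alpha = (alpha.unsqueeze(-1) + edges[t]).max(dim=0).values
    best = alpha.max()
    # The subgradient of max is 1-hot at the argmax, so backward()
    # deposits exactly the Viterbi path's edges into edges.grad.
    best.backward()
    return best, edges.grad  # edges.grad[t] is 1-hot at the chosen transition
```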

@allanj
Author

allanj commented Mar 11, 2021

Thanks. That helps a lot. I think I need some more time to figure out the details of .backward.

What I did on my side for the moment is to implement the complete parallel scan algorithm, for both inference and backtracking (https://github.com/allanj/pytorch_lstmcrf/blob/40f8980dde/src/model/module/fast_linear_crf_inferencer.py). Maybe this is stupid, :(, but it helps me better understand the procedure.
I will spend some more time figuring out the details.

Right now, I think the memory usage in my back-tracking code could be larger, because during the parallel-scan "forward/backward" pass over the tree, I store an N-length vector representing the best sequence at each node.

@srush
Collaborator

srush commented Mar 11, 2021

Neat! Yeah, this should give the same outcome. My point, though, is that you do not need to write the backward pass manually, even if doing so might lead to some speed-ups in practice.

@srush
Collaborator

srush commented Mar 11, 2021

Btw, your repo looks really nice.

If you wanted to build some more transformer-backed structured prediction models, I would be happy to collaborate. It would be nice to have a single repo with a BERT tagger / NER / CKY, all backed by Hugging Face models and HF datasets.

@allanj
Author

allanj commented Mar 12, 2021

Sure, that would be great. One of my goals is exactly that: building these models on top of the current HF backend. CKY is something that I'm really looking forward to.

Another thing is a general hypergraph framework (from a research perspective), though I think it is still pretty challenging to implement such a general framework in PyTorch in practice. That is partly why I came to learn from this repo as well.

@srush
Collaborator

srush commented Mar 12, 2021

Neat! Me too. General hypergraphs would be really interesting. But I think we would need to write that in CUDA or TVM manually.

The one thing I care a lot about is testing. I want to be sure that when people implement a CRF, they are really computing the partition function.
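(For a linear-chain CRF that test is cheap to write: enumerate every label sequence on a tiny input and compare against the dynamic program. A sketch, my own code rather than the repo's test suite:)

```python
import itertools
import torch

def brute_force_logpartition(log_potentials):
    # log_potentials: (N-1) x C x C edge scores of a linear-chain CRF.
    # Score all C**N label sequences explicitly; only feasible for tiny
    # N and C, which is exactly what makes it a trustworthy oracle.
    n_edges, C, _ = log_potentials.shape
    scores = []
    for path in itertools.product(range(C), repeat=n_edges + 1):
        scores.append(sum(log_potentials[t, path[t], path[t + 1]]
                          for t in range(n_edges)))
    return torch.logsumexp(torch.stack(scores), dim=0)

# e.g.: lp = torch.randn(3, 2, 2)
# assert torch.allclose(brute_force_logpartition(lp), logpartition_scan(lp))
```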
