
Conditional and joint span probabilities in TreeCRF #107

Open
rubencart opened this issue Sep 27, 2021 · 5 comments

Comments

@rubencart

I want to use the TreeCRF class to learn latent tree distributions for constituency trees for sentences. I noticed you can easily obtain the text span marginals with .marginals. However, I am interested in computing more probabilities in the tree distribution, like the conditional probability that one span occurs in the tree, given that another one occurs, or the joint probability of two spans. Is there an easy way to compute these probabilities from the marginals? Or using different torch-struct functionality?

A 'dirty' trick for the conditional probability could be to recompute the marginals with the potential of the span you want to condition on set to a very high value; the new marginals would then effectively be conditional probabilities. But that requires running the parsing algorithm once per conditioning span, which I would ideally like to avoid.
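To make the trick concrete, here is a minimal sketch on a toy enumeration (an assumed stand-in for a tree distribution, not torch-struct's actual internals): each "structure" is a set of span indices scored by summing per-span log-potentials, and clamping one span's potential high makes the re-read marginals approximate conditionals.

```python
import math

# Toy stand-in for a tree distribution (assumed setup, not torch-struct):
# each structure t is a set of span indices, p(t) ∝ exp(sum_{j in t} l_j).
structures = [{0, 1}, {0, 2}, {1, 2}]
log_pot = [0.5, -0.3, 1.2]  # one log-potential per span

def marginals(l):
    scores = [math.exp(sum(l[j] for j in t)) for t in structures]
    Z = sum(scores)
    return [sum(s for t, s in zip(structures, scores) if j in t) / Z
            for j in range(len(l))]

# Exact conditional p(span 1 | span 0) by direct enumeration, for reference.
Z = sum(math.exp(sum(log_pot[j] for j in t)) for t in structures)
p0 = marginals(log_pot)[0]
p01 = sum(math.exp(sum(log_pot[j] for j in t))
          for t in structures if {0, 1} <= t) / Z
exact = p01 / p0

# The "dirty trick": add a large constant to span 0's log-potential and
# re-read span 1's marginal, which is now (approximately) the conditional.
clamped = list(log_pot)
clamped[0] += 50.0  # large enough to dominate, small enough not to overflow
approx = marginals(clamped)[1]
```

The two numbers agree up to an error of order exp(-clamp), which is negligible for any reasonably large clamp.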

@srush
Collaborator

srush commented Sep 27, 2021

What a great question...

If there are two specific spans that you need, then your "dirty trick" is the right way to do it.

If you want to do it efficiently for any pair of spans, then there are some fun auto-diff tricks you can use. If I remember correctly, the Hessian of the log-partition (with respect to the log-potentials) gives you the covariance of all pairs of spans, from which the joint of any pair follows by adding back the product of their marginals. I don't think this is currently implemented in the library, but it wouldn't be hard to add.

Maybe look at https://pytorch.org/docs/stable/generated/torch.autograd.functional.hessian.html
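A minimal sketch of that Hessian trick on the same kind of toy enumeration (assumed setup, not torch-struct code): each row is a 0/1 indicator vector over spans, p(t) ∝ exp(t · l), and the Hessian of the log-partition plus the outer product of the marginals gives all pairwise joints at once.

```python
import torch
from torch.autograd.functional import hessian

# Toy enumeration standing in for a tree distribution (assumed setup):
# each row is a 0/1 indicator vector over spans, p(t) ∝ exp(t · l).
parts = torch.tensor([[1., 1., 0.],
                      [1., 0., 1.],
                      [0., 1., 1.]])

def log_partition(l):
    return torch.logsumexp(parts @ l, dim=0)

l = torch.tensor([0.5, -0.3, 1.2])
p = torch.softmax(parts @ l, dim=0)   # probability of each structure
mu = parts.t() @ p                    # span marginals E[x_j]

# Hessian of the log-partition = covariance Cov(x_j, x_k);
# adding back mu_j * mu_k recovers the pairwise joint E[x_j x_k].
H = hessian(log_partition, l)
joint = H + torch.outer(mu, mu)
```

One Hessian evaluation yields the full matrix of pairwise joints, versus one parsing pass per conditioning span for the clamping trick.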

@rubencart
Author

rubencart commented Sep 27, 2021

Thank you :-). I am more interested in an efficient way for any pair of spans.
The Hessian trick sounds interesting; could you perhaps point me to a reference that explains this relation? And/or to references relating other auto-diff tricks to other probabilities (like the conditionals)?

@srush
Collaborator

srush commented Sep 27, 2021

I think this is a nice reference for Bayes nets: https://dl.acm.org/doi/pdf/10.1145/765568.765570


Alternatively, you can think of the CRF as an exponential family, so the log-partition function generates moments:

https://www.cs.cmu.edu/~epxing/Class/10708-14/scribe_notes/scribe_note_lecture6.pdf

I can't find a nice reference that explains the Hessian, but you can derive it by differentiating the log-partition \log \sum_i \exp(l_i) twice, with respect to l_j and l_k: you end up with a term that has the sum over all structures containing parts j and k in the numerator and the partition function in the denominator.
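Spelling that derivation out, with A(l) = \log \sum_t \exp\big(\sum_{i \in t} l_i\big), p(t) the probability of structure t, and \mu_j the marginal of part j:

\frac{\partial A}{\partial l_j} = \frac{\sum_{t \ni j} \exp\big(\sum_{i \in t} l_i\big)}{\sum_t \exp\big(\sum_{i \in t} l_i\big)} = \sum_{t \ni j} p(t) = \mu_j

\frac{\partial^2 A}{\partial l_j \partial l_k} = \sum_{t \ni j,\, k} p(t) - \mu_j \mu_k

So each Hessian entry is a covariance, and the pairwise joint \sum_{t \ni j, k} p(t) is recovered as the Hessian entry plus \mu_j \mu_k.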

If you are feeling brave, it is also in this paper:

https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00391/102843/Efficient-Computation-of-Expectations-under

@rubencart
Author

That's super interesting, thank you for your help!

I am going to try to use this in my project, but would you like me to make an attempt at a PR to add it to this library as well? If so, do you have any pointers on where to add it?

Some last questions :-) :
The derivations in the last paper are for distributions over spanning trees of graphs, defined by edge weights. I think they remain valid for distributions over constituency trees if you just replace the edge weights by span weights, right? (Not the efficient computations using the Matrix-Tree Theorem, of course, but the relations between partial derivatives and expectations.)

Also, am I right to deduce that 3rd-order partial derivatives of the partition function then give joint probabilities of 3 edges (assuming the 2nd-order derivatives are differentiable)? Unless I'm making a mistake, this is easy to prove in the same way as the relation between the first- and second-order derivatives and the marginals and pairwise joints, respectively.
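One wrinkle worth noting (a sketch under the same toy-enumeration assumption, not torch-struct code): the 3rd-order partials of the log-partition give the third joint cumulants, which equal the third central moments, so the 3-way joint E[x_j x_k x_m] needs lower-order moment terms added back, just as the pairwise joint needs \mu_j \mu_k. A quick numerical check with repeated autograd:

```python
import torch

# Toy enumeration (assumed setup): rows are 0/1 span indicator vectors,
# p(t) ∝ exp(t · l).
parts = torch.tensor([[1., 1., 0.],
                      [1., 0., 1.],
                      [0., 1., 1.],
                      [1., 1., 1.]])
l = torch.tensor([0.5, -0.3, 1.2], requires_grad=True)

A = torch.logsumexp(parts @ l, dim=0)
(g,) = torch.autograd.grad(A, l, create_graph=True)   # 1st order: marginals
n = l.shape[0]
third = torch.zeros(n, n, n)
for j in range(n):
    (h,) = torch.autograd.grad(g[j], l, create_graph=True)   # 2nd order
    for k in range(n):
        (t3,) = torch.autograd.grad(h[k], l, retain_graph=True)  # 3rd order
        third[j, k] = t3

# Direct check: 3rd-order joint cumulants are the third central moments.
p = torch.softmax(parts @ l.detach(), dim=0)
mu = parts.t() @ p
c = parts - mu                          # centered indicators, per structure
central = torch.einsum('t,tj,tk,tm->jkm', p, c, c, c)
```

Here `third` matches `central`, confirming that the 3rd derivatives are central (not raw) moments; the raw joint then follows from the standard moment-cumulant relations.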

@srush
Collaborator

srush commented Sep 30, 2021 via email
