Linear Chain CRF with SGD method #11

Open. Wants to merge 8 commits into base: master

chyikwei

Hi,

I tried to implement a linear-chain CRF trained with SGD, based on these two papers:

C. Sutton, "An Introduction to Conditional Random Fields for Relational Learning"

N. Schraudolph, "Accelerated Training of Conditional Random Fields with Stochastic Gradient Methods"

To measure performance, I tested it on the CoNLL 2000 shared task and got:

| Model | Iterations | Training F1 | Test F1 | Training time |
|---|---|---|---|---|
| CRF | 5 | 0.894 | 0.844 | 68 mins |
| Structured Perceptron | 10 | 0.892 | 0.834 | 1 min |

I think the biggest problem in my implementation is the training time. I will try to work on parallel training or improve the Cython code.

Maybe you can review it and give me some feedback.
Any suggestions are welcome.
Thanks!

@larsmans
Owner

Sweet! As for the time, have you tried profiling with kernprof.py?

I suspect calling logsumexp from inside Cython might have something to do with this, too.
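For readers unfamiliar with the operation, here is a minimal NumPy sketch of what a log-sum-exp computes. The name `logsumexp_sketch` is hypothetical and this is not the function used in the PR; it only illustrates the max-subtraction trick that keeps the exponentials from overflowing. Inside Cython, the same arithmetic could presumably be written as a typed loop to avoid a Python-level call per step, though that is an assumption about where the overhead comes from.

```python
import numpy as np


def logsumexp_sketch(a, axis=None):
    """Compute log(sum(exp(a))) along `axis` without overflow (illustrative only)."""
    a = np.asarray(a, dtype=float)
    # Subtract the maximum so the largest term becomes exp(0) = 1.
    a_max = np.max(a, axis=axis, keepdims=True)
    # If the maximum is infinite, skip the shift to avoid nan from inf - inf.
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    # Return a plain float for a full reduction, otherwise drop the reduced axis.
    return out.item() if axis is None else np.squeeze(out, axis=axis)


if __name__ == "__main__":
    # A naive np.log(np.sum(np.exp(...))) would overflow on these inputs.
    print(logsumexp_sketch([1000.0, 1000.0]))
```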

@chyikwei
Author

I have not tried it yet. Will do this next week.

@chyikwei
Author

I just profiled LinearChainCRF.fit(). Here are some results:

  1. Computing the posterior takes about 20% of the total time:

Line: 146, Time Per Hit: 12021.4, % Time: 19.2
post_state, post_trans, ll = _posterior(score, None, b_trans, b_init, b_final)

  2. Computing w_update takes about 41% of the total time:

Line: 162, Time Per Hit: 26131.2, % Time: 41.7
w_update = lr * (safe_sparse_dot(y_t_i.T, X_i) - safe_sparse_dot(post_state.T, X_i) - (reg * w))

  3. Computing the objective function takes about 20% of the total time:

Line: 149, Time Per Hit: 7178.1, % Time: 11.5
feature_val = np.sum(w_true * w)

Line: 153, Time Per Hit: 6169.9, % Time: 9.9
sum_obj_val += feature_val + trans_val + init_val + final_val - ll - (0.5 * reg * np.sum(w * w))
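The body of `_posterior` is not shown in this thread. As a rough illustration of the kind of work the posterior step in item 1 involves, here is a log-space forward-backward sketch for a linear-chain model. Everything here is an assumption: the function name `posterior_sketch`, its argument list (which differs from the five-argument call above), and the choice to return transition marginals summed over positions. The third return value is the log-partition function, which appears to be the role `ll` plays in the objective on line 153.

```python
import numpy as np


def posterior_sketch(score, trans, init, final):
    """Forward-backward in log space (hypothetical helper, not the PR's code).

    score: (n_steps, n_states) per-position state scores
    trans: (n_states, n_states) scores for moving from state i to state j
    init, final: (n_states,) scores for starting / ending in each state
    """
    n_steps, n_states = score.shape
    alpha = np.empty((n_steps, n_states))
    beta = np.empty((n_steps, n_states))

    # Forward pass: alpha[t, j] = score[t, j] + log sum_i exp(alpha[t-1, i] + trans[i, j])
    alpha[0] = init + score[0]
    for t in range(1, n_steps):
        alpha[t] = score[t] + np.logaddexp.reduce(alpha[t - 1][:, None] + trans, axis=0)

    # Backward pass: beta[t, i] = log sum_j exp(trans[i, j] + score[t+1, j] + beta[t+1, j])
    beta[-1] = final
    for t in range(n_steps - 2, -1, -1):
        beta[t] = np.logaddexp.reduce(trans + score[t + 1] + beta[t + 1], axis=1)

    # Log-partition function: log of the summed scores of all label sequences.
    log_z = np.logaddexp.reduce(alpha[-1] + final)

    # Per-position state marginals; each row sums to 1.
    post_state = np.exp(alpha + beta - log_z)

    # Expected transition counts, summed over positions 1..n_steps-1.
    post_trans = np.zeros((n_states, n_states))
    for t in range(1, n_steps):
        post_trans += np.exp(alpha[t - 1][:, None] + trans + score[t] + beta[t] - log_z)

    return post_state, post_trans, log_z
```

Both passes loop over positions in Python and reduce over states at every step, which is one plausible reason this step shows up prominently in a line profile.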

@chyikwei
Author

I saved some time on computing w_update. New profiling result for that line:

Line: 162, Time Per Hit: 10535.2, % Time: 25.1
w_update = lr * (safe_sparse_dot(y_t_i.T, X_i) - safe_sparse_dot(post_state.T, X_i) - (reg * w))
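The thread does not say how the time was saved. One rearrangement that might reduce the cost further, shown only as an illustration, is to subtract the two left-hand matrices first and do a single product, since `A.T @ X - B.T @ X` equals `(A - B).T @ X`. The sketch below uses dense NumPy arrays and hypothetical function names; the PR itself uses `safe_sparse_dot`, and with sparse inputs the subtraction could densify the operand, so any gain would need to be measured.

```python
import numpy as np


def w_update_two_products(X, y_true, post_state, w, lr, reg):
    # Mirrors the structure of the profiled line: two matrix products.
    return lr * (y_true.T @ X - post_state.T @ X - reg * w)


def w_update_one_product(X, y_true, post_state, w, lr, reg):
    # Same result with one matrix product, using distributivity.
    return lr * ((y_true - post_state).T @ X - reg * w)


# Small synthetic example; all shapes and values are made up for illustration.
rng = np.random.RandomState(0)
n_tokens, n_features, n_states = 6, 4, 3
X = rng.rand(n_tokens, n_features)
y_true = np.eye(n_states)[rng.randint(n_states, size=n_tokens)]   # one-hot labels
post_state = rng.dirichlet(np.ones(n_states), size=n_tokens)      # rows sum to 1
w = rng.rand(n_states, n_features)

update_a = w_update_two_products(X, y_true, post_state, w, lr=0.1, reg=0.01)
update_b = w_update_one_product(X, y_true, post_state, w, lr=0.1, reg=0.01)
print(np.allclose(update_a, update_b))
```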
