For dcn model #6

Open
jiarenyf opened this issue May 3, 2018 · 5 comments

Comments

@jiarenyf

jiarenyf commented May 3, 2018

In the line

x_l = torch.sum(x_0 * x_l, 1).view([-1,1]) * getattr(self,'cross_weight_'+str(i+1)).view([1,-1]) + getattr(self,'cross_bias_'+str(i+1)) + x_l

x_0 * x_l should be replaced by torch.matmul(x_0, x_l.t()), right?

@nzc
Owner

nzc commented May 4, 2018

The x_l's size should be batch_size * [field_size * embedding_size]. In the DCN paper, x_l's size is [field_size * embedding_size], so the result of x_0 * x_l^T is rank-one. Extending this to batch_size * [field_size * embedding_size], the result's size should be batch_size * 1. If we use torch.matmul(x_0, x_l.t()), the size is wrong.
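A minimal NumPy sketch of the shape argument, using NumPy as a stand-in for the PyTorch calls (x.sum(axis=1) for torch.sum(x, 1), @ for torch.matmul). The batch size 4 and feature width 6 are hypothetical. On 2-D batched inputs, the row-wise sum gives one dot product per sample, while the matmul compares every sample with every other sample:

```python
import numpy as np

# Hypothetical sizes: batch of 4 samples, field_size * embedding_size = 6.
B, D = 4, 6
rng = np.random.default_rng(0)
x_0 = rng.standard_normal((B, D))  # [batch_size, field_size * embedding_size]
x_l = rng.standard_normal((B, D))

# NumPy analogue of torch.sum(x_0 * x_l, 1).view([-1, 1]): one dot product per sample.
per_sample = np.sum(x_0 * x_l, axis=1).reshape(-1, 1)   # shape [B, 1]

# NumPy analogue of torch.matmul(x_0, x_l.t()): every sample against every other sample.
cross_batch = x_0 @ x_l.T                                # shape [B, B]

print(per_sample.shape, cross_batch.shape)  # (4, 1) (4, 4)

# Only the diagonal of the [B, B] matrix holds the per-sample products.
assert np.allclose(np.diag(cross_batch), per_sample.ravel())
```

The off-diagonal entries of the [B, B] result mix different samples, which is why that shape does not fit a per-sample cross layer.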

@nzc
Owner

nzc commented May 4, 2018

@jiarenyf

@jiarenyf
Author

jiarenyf commented May 6, 2018

The shapes of x0, xl, wl, and bl are all [fieldSize*embeddingSize, 1] (ignoring the batch axis), and the formula for xl should be xl = matmul(matmul(x0, xl.T), wl) + bl + xl, as shown in the following image:
[Image: the cross-layer equation from the DCN paper]
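A single-sample sketch of that formula in NumPy, assuming all four vectors are [D, 1] columns with a hypothetical D = 6. It evaluates the two matmuls literally (which builds a [D, D] matrix) and also the reassociated form x0 @ (xl.T @ wl), where xl.T @ wl is a 1x1 scalar. By associativity of matrix products the two agree:

```python
import numpy as np

D = 6  # hypothetical field_size * embedding_size
rng = np.random.default_rng(1)
x0, xl, wl, bl = (rng.standard_normal((D, 1)) for _ in range(4))  # all [D, 1]

def cross_layer_literal(x0, xl, wl, bl):
    # matmul(matmul(x0, xl.T), wl) + bl + xl, materialising the [D, D] outer product
    return (x0 @ xl.T) @ wl + bl + xl

def cross_layer_reassociated(x0, xl, wl, bl):
    # Same value by associativity: xl.T @ wl is [1, 1], so it just scales x0
    return x0 * (xl.T @ wl) + bl + xl

out_a = cross_layer_literal(x0, xl, wl, bl)
out_b = cross_layer_reassociated(x0, xl, wl, bl)
print(out_a.shape)  # (6, 1)
assert np.allclose(out_a, out_b)
```

The reassociated form avoids the [D, D] intermediate, which matters when field_size * embedding_size is large.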

I am not familiar with PyTorch, but in MXNet I use batch_dot to implement the calculation of xl, as in here ...
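MXNet's batch_dot multiplies matrices independently for each batch element. np.matmul on 3-D arrays behaves the same way (as do torch.bmm and torch.matmul in PyTorch). A batched sketch of the formula, assuming per-sample [D, 1] columns stacked into [B, D, 1], a weight and bias shared across the batch, and hypothetical sizes B = 4, D = 6:

```python
import numpy as np

B, D = 4, 6  # hypothetical batch size and field_size * embedding_size
rng = np.random.default_rng(2)
x0 = rng.standard_normal((B, D, 1))   # per-sample column vectors
xl = rng.standard_normal((B, D, 1))
wl = rng.standard_normal((D, 1))      # shared across the batch
bl = rng.standard_normal((D, 1))

# Batched outer products x0 @ xl^T: [B, D, 1] @ [B, 1, D] -> [B, D, D]
outer = np.matmul(x0, np.transpose(xl, (0, 2, 1)))
# Multiply each [D, D] matrix by the shared weight: [B, D, D] @ [D, 1] -> [B, D, 1]
x_next = np.matmul(outer, wl) + bl + xl

print(outer.shape, x_next.shape)  # (4, 6, 6) (4, 6, 1)
```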

@nzc
Owner

nzc commented May 6, 2018

@jiarenyf The x_0 in my code is two-dimensional, and I do the same thing as batch_dot in my code.
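A sketch for comparing the two expressions on 2-D [B, D] inputs, using a NumPy transcription of the line quoted at the top of the issue and a batched form of the paper's formula. The sizes, seed, and the flattening of cross_weight and cross_bias to shape [D] are assumptions. Both expressions produce [B, D] outputs, so shapes alone do not distinguish them; the sketch checks whether the values agree:

```python
import numpy as np

B, D = 4, 6  # hypothetical sizes
rng = np.random.default_rng(3)
x_0 = rng.standard_normal((B, D))
x_l = rng.standard_normal((B, D))
w = rng.standard_normal(D)   # cross_weight, flattened to [D] (assumption)
b = rng.standard_normal(D)   # cross_bias, flattened to [D] (assumption)

# Transcription of: sum(x_0 * x_l, 1).view([-1, 1]) * w.view([1, -1]) + b + x_l
# The per-sample scalar is x_0 . x_l, and it scales w.
code_expr = np.sum(x_0 * x_l, axis=1).reshape(-1, 1) * w.reshape(1, -1) + b + x_l

# Batched paper formula x_0 x_l^T w + b + x_l.
# The per-sample scalar is x_l . w, and it scales x_0.
paper_expr = x_0 * (x_l @ w).reshape(-1, 1) + b + x_l

print(code_expr.shape, paper_expr.shape)  # (4, 6) (4, 6)
print(np.allclose(code_expr, paper_expr))
```

On random inputs the two generally differ, because the scalar multiplies a different vector in each expression; they coincide only in special cases.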

@jiarenyf
Author

jiarenyf commented May 7, 2018

But the result of x0*xl^T is not rank-one; it should be [batchSize, fieldSize*embeddingSize, fieldSize*embeddingSize], and the shape of x0*xl^T*wl is [batchSize, fieldSize*embeddingSize]. Here I use * to represent matmul.
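A NumPy sketch of the shapes in that comment, with hypothetical sizes B = 4, D = 6. It builds the batched outer products x0*xl^T as [B, D, D], checks the matrix rank of each [D, D] slice, and multiplies by wl to get [B, D]. Each slice is an outer product of two vectors, so it has matrix rank 1 even though it is a full [D, D] array:

```python
import numpy as np

B, D = 4, 6  # hypothetical sizes
rng = np.random.default_rng(4)
x0 = rng.standard_normal((B, D))
xl = rng.standard_normal((B, D))
wl = rng.standard_normal(D)

# Batched outer products x0 xl^T: [B, D, D]
outer = np.einsum("bi,bj->bij", x0, xl)
# matrix_rank accepts a stack of matrices and returns one rank per sample
ranks = np.linalg.matrix_rank(outer)
# [B, D, D] @ [D] -> [B, D]
result = outer @ wl

print(outer.shape, result.shape)  # (4, 6, 6) (4, 6)
print(ranks)
```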


2 participants