Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ContrastiveOutput #1191

Merged
merged 5 commits into from
Jul 11, 2023
Merged

Adding ContrastiveOutput #1191

merged 5 commits into from
Jul 11, 2023

Conversation

marcromeyn
Copy link
Contributor

@marcromeyn marcromeyn commented Jul 7, 2023

Goals ⚽

This PR builds on the CategoricalOutput and adds ContrastiveOutput. This can be used with dot-product, categorical-prediction or weight-tying.

Implementation Details 🚧

The contrastive part of CategoricalOutput doesn't work with torch-script, this is fine since it's during training only.

@github-actions
Copy link

github-actions bot commented Jul 7, 2023

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1191

@marcromeyn marcromeyn self-assigned this Jul 8, 2023
@marcromeyn marcromeyn added enhancement New feature or request area/pytorch labels Jul 8, 2023
@marcromeyn marcromeyn marked this pull request as ready for review July 8, 2023 11:46
@marcromeyn marcromeyn requested a review from sararb July 10, 2023 13:28
Copy link
Member

@gabrielspmoreira gabrielspmoreira left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me. Good job!

assert isinstance(dot.to_call, DotProduct)

target = ContrastiveOutput(schema=Schema([item_id_col_schema]))
assert isinstance(target.to_call, CategoricalTarget)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is neat! Same interface for contrastive for two-tower like architectures and for sampled softmax.

from merlin.models.torch.block import registry


class LogUniformSampler(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcromeyn Have you checked the LogUniformSampler class I created in T4Rec based on this one?
I adds some additional options like returning unique samples and also provides the probs for the items for logQ correction (considering whether returning unique samples or popularity biased).
It matches the implementation of sampling and probs from tf.random.log_uniform_candidate_sampler

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not, yours looks much better! I can port that

self.false_negative_score = false_negative_score

@classmethod
def with_weight_tying(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great syntax sugar!

@marcromeyn marcromeyn merged commit 190cd48 into main Jul 11, 2023
37 checks passed
@marcromeyn marcromeyn deleted the torch/contrastive-output branch July 11, 2023 08:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants