Introducing BatchBlock #1192
Conversation
Force-pushed from fc1d25f to 1582d6a
Force-pushed from 1582d6a to 69d17df
The PR looks good to me! I've just left some minor comments/questions
        A dictionary containing all the flattened features, targets, and sequences.
        """
        flat_dict: Dict[str, torch.Tensor] = self._flatten()
        dummy_tensor = torch.tensor(0)
Can you explain why we need the `dummy_tensor` variable?
We never use its value, but we need it to keep the dict typed as `Dict[str, torch.Tensor]`. We could store the original value instead, but that might take more memory, so that's why I added the dummy. We only care about the keys of the original inputs.
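The idea can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: the function name `flatten_with_input_keys` and its signature are hypothetical; only the `inputs.` prefix, the `dummy_tensor`, and the `Dict[str, torch.Tensor]` typing come from the PR discussion.

```python
from typing import Dict, List

import torch


def flatten_with_input_keys(
    flat_dict: Dict[str, torch.Tensor], input_keys: List[str]
) -> Dict[str, torch.Tensor]:
    # Hypothetical sketch: a shared zero-dim dummy tensor records which keys
    # were present in the original inputs without holding on to their
    # (possibly large) values, so the dict stays a homogeneous
    # Dict[str, torch.Tensor].
    dummy_tensor = torch.tensor(0)
    for key in input_keys:
        flat_dict[f"inputs.{key}"] = dummy_tensor
    return flat_dict


flat = flatten_with_input_keys({"features.a": torch.tensor([1, 2])}, ["a", "b"])
# flat now holds "features.a" plus "inputs.a" / "inputs.b" placeholders
```

Because every entry under `inputs.` points at the same zero-dim tensor, the memory cost of remembering the original keys is effectively constant.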
result = batch.flatten_as_dict(input_batch)
assert len(result) == 9  # input keys are considered
assert (
    len([k for k in result if k.startswith("inputs.")]) == 4
What is the difference between `inputs.` and `features.` keys?
`inputs.` is what went into the batch-transformation; it's used by the mechanism that restores values of the batch that were not transformed by any of the branches.
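A small illustration of that restore mechanism, under stated assumptions: the dict literal below is invented for the example (the key names `user_id` / `item_id` are hypothetical); only the `inputs.` / `features.` prefixes and the placeholder-value idea come from the PR.

```python
import torch

# "features." keys hold transformed tensors; "inputs." keys (with
# placeholder values) record what originally went into the
# batch-transformation.
flat = {
    "inputs.user_id": torch.tensor(0),  # placeholder: only the key matters
    "inputs.item_id": torch.tensor(0),
    "features.user_id": torch.tensor([1, 2]),
}

# An input key with no transformed counterpart was untouched by any branch,
# so its original value can be restored after the transformation.
untransformed = [
    k[len("inputs."):]
    for k in flat
    if k.startswith("inputs.") and "features." + k[len("inputs."):] not in flat
]
# untransformed == ["item_id"]
```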
def test_in_parallel(self):
    feat, target = torch.tensor([1, 2]), torch.tensor([3, 4])
    outputs = module_utils.module_test(
+1 to how we can now create different preprocessing blocks for each subset of the inputs! I love it!!
Goals ⚽
BatchBlock will be used inside the Model to create the Batch object. It's also useful for things like masking & padding.
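To make the padding use case concrete, here is a minimal sketch of the kind of operation such a block might perform. The function `pad_sequences` is hypothetical, not part of the PR; it only illustrates padding variable-length sequences in a batch to a common length.

```python
import torch


def pad_sequences(seqs, pad_value=0):
    # Hypothetical sketch: pad each 1-D sequence tensor with `pad_value`
    # up to the length of the longest sequence, then stack into a batch.
    max_len = max(s.size(0) for s in seqs)
    return torch.stack(
        [
            torch.cat([s, s.new_full((max_len - s.size(0),), pad_value)])
            for s in seqs
        ]
    )


padded = pad_sequences([torch.tensor([1, 2]), torch.tensor([3])])
# padded -> tensor([[1, 2], [3, 0]])
mask = padded != 0  # boolean mask marking real (non-padded) positions,
                    # assuming 0 never appears as a real value
```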