Skip to content

Commit

Permalink
Restore mask_position argument name (#1185)
Browse files Browse the repository at this point in the history
This was actually breaking because of a silly bug in keras-core:
keras-team/keras-core#632

We can bring back the original name, though we will need to wait
until the fix is in a release to restore `compute_output_shape`.
  • Loading branch information
mattdangerw authored Jul 31, 2023
1 parent b117fbc commit 1437a3f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
32 changes: 17 additions & 15 deletions keras_nlp/layers/modeling/masked_lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self._built = False

if vocabulary_size is None and embedding_weights is None:
raise ValueError(
Expand All @@ -142,7 +141,7 @@ def __init__(
)
self.vocabulary_size = shape[0]

def build(self, inputs_shape, masked_positions_shape=None):
def build(self, inputs_shape):
if self.embedding_weights is not None:
feature_size = self.embedding_weights.shape[-1]
else:
Expand All @@ -157,12 +156,13 @@ def build(self, inputs_shape, masked_positions_shape=None):
self._layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
if masked_positions_shape:
gather_length = masked_positions_shape[1]
shape = (inputs_shape[0], gather_length, inputs_shape[-1])
self._dense.build(shape)
shape = (inputs_shape[0], gather_length, feature_size)
self._layer_norm.build(shape)
# The gather length does not affect any of our built variables, so
# we can pass any value here.
gather_length = None
shape = (inputs_shape[0], gather_length, inputs_shape[-1])
self._dense.build(shape)
shape = (inputs_shape[0], gather_length, feature_size)
self._layer_norm.build(shape)
if self.embedding_weights is None:
self._kernel = self.add_weight(
name="output_kernel",
Expand All @@ -177,10 +177,10 @@ def build(self, inputs_shape, masked_positions_shape=None):
dtype=self.dtype,
)

def call(self, inputs, masked_positions):
def call(self, inputs, mask_positions):
# Gather the encoded tokens at the masked indices.
masked_positions = ops.expand_dims(masked_positions, axis=-1)
x = ops.take_along_axis(inputs, masked_positions, axis=1)
mask_positions = ops.expand_dims(mask_positions, axis=-1)
x = ops.take_along_axis(inputs, mask_positions, axis=1)

# Apply a trainable linear transformation and a layer norm.
x = self._dense(x)
Expand Down Expand Up @@ -221,7 +221,9 @@ def get_config(self):
)
return config

def compute_output_shape(self, inputs_shape, masked_positions_shape):
output_shape = list(masked_positions_shape)
output_shape[-1] = self.vocabulary_size
return tuple(output_shape)
# TODO: restore this after https://github.com/keras-team/keras-core/pull/632
# is in a release!
# def compute_output_shape(self, inputs_shape, mask_positions_shape):
# output_shape = list(mask_positions_shape)
# output_shape[-1] = self.vocabulary_size
# return tuple(output_shape)
8 changes: 4 additions & 4 deletions keras_nlp/layers/modeling/masked_lm_head_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_valid_call(self):
)
encoded_tokens = keras.Input(shape=(10, 16))
positions = keras.Input(shape=(5,), dtype="int32")
outputs = head(encoded_tokens, masked_positions=positions)
outputs = head(encoded_tokens, mask_positions=positions)
model = keras.Model((encoded_tokens, positions), outputs)

token_data = ops.random.uniform(shape=(4, 10, 16))
Expand All @@ -48,7 +48,7 @@ def test_valid_call_with_embedding_weights(self):
# need to support this in the layer.
sequence = keras.Input(shape=(10, 32))
positions = keras.Input(shape=(5,), dtype="int32")
outputs = head(sequence, masked_positions=positions)
outputs = head(sequence, mask_positions=positions)
model = keras.Model((sequence, positions), outputs)
sequence_data = ops.random.uniform(shape=(4, 10, 32))
position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_one_train_step(self):
)
encoded_tokens = keras.Input(shape=(10, 16))
positions = keras.Input(shape=(5,), dtype="int32")
outputs = head(encoded_tokens, masked_positions=positions)
outputs = head(encoded_tokens, mask_positions=positions)
model = keras.Model((encoded_tokens, positions), outputs)

token_data = ops.random.uniform(shape=(4, 10, 16))
Expand All @@ -126,7 +126,7 @@ def test_saved_model(self):
)
encoded_tokens = keras.Input(shape=(10, 16))
positions = keras.Input(shape=(5,), dtype="int32")
outputs = head(encoded_tokens, masked_positions=positions)
outputs = head(encoded_tokens, mask_positions=positions)
model = keras.Model((encoded_tokens, positions), outputs)

token_data = ops.random.uniform(shape=(4, 10, 16))
Expand Down

0 comments on commit 1437a3f

Please sign in to comment.