In 4-3 — it is not a bug, just a question I don't understand. Why can we call `model.add(Linear(units=1, input_shape=(2,)))` when `input_shape=(2,)` is not a parameter of the `__init__` method?
```python
import tensorflow as tf
from tensorflow.keras import layers, models


class Linear(layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    # The trainable parameters are defined in the build method.
    # Since input_shape is only needed inside build, we do not
    # need to store it in __init__.
    def build(self, input_shape):
        # The name "w" is compulsory, or an error will be thrown.
        self.w = self.add_weight("w",
                                 shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight("b",
                                 shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)
        super().build(input_shape)  # Identical to self.built = True

    # The forward-propagation logic is defined in the call method,
    # which is invoked by the __call__ method.
    @tf.function
    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    # A customized get_config method is needed to save the model in h5
    # format, specifically when the model is composed through the
    # Functional API with a customized Layer.
    def get_config(self):
        config = super().get_config()
        config.update({'units': self.units})
        return config


tf.keras.backend.clear_session()
model = models.Sequential()
# Note: the input_shape here will be adjusted by the model, so we don't
# have to put None in the dimension representing the number of samples.
model.add(Linear(units=1, input_shape=(2,)))
print("model.input_shape: ", model.input_shape)
print("model.output_shape: ", model.output_shape)
model.summary()
```
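The `get_config` pattern in the code above matters for serialization: Keras's default `from_config` rebuilds a layer as `cls(**config)`, so the config must contain both the base class's fields and the subclass's own. A minimal sketch of that round trip, with `ConfigurableLayer` as a hypothetical stand-in for the Keras base class (no TensorFlow needed):

```python
# Hypothetical stand-in for the Keras base Layer's config machinery.
class ConfigurableLayer:
    def __init__(self, name="layer", **kwargs):
        self.name = name

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Keras's default from_config: rebuild the layer from its config.
        return cls(**config)


class Linear(ConfigurableLayer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        config = super().get_config()         # base-class fields (name, ...)
        config.update({"units": self.units})  # subclass's own field
        return config


layer = Linear(units=1, name="dense_out")
clone = Linear.from_config(layer.get_config())
print(clone.units, clone.name)  # 1 dense_out
```

Omitting the `super().get_config()` merge would drop the base fields and break reconstruction, which is why the custom layer updates the parent's config rather than returning a fresh dict.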