diff --git a/larq_zoo/literature/meliusnet.py b/larq_zoo/literature/meliusnet.py index 68a7c814..52e67dda 100644 --- a/larq_zoo/literature/meliusnet.py +++ b/larq_zoo/literature/meliusnet.py @@ -139,40 +139,51 @@ def build(self) -> tf.keras.models.Model: inputs=self.image_input, outputs=x, name=self.name ) + if self.weights == "imagenet": + if self.include_top: + weights_path = self.imagenet_weights_path + else: + weights_path = self.imagenet_no_top_weights_path + model.load_weights(weights_path) + elif self.weights is not None: + model.load_weights(self.weights) + return model +###################### +# Concrete factories # +###################### + + @factory class MeliusNet22Factory(MeliusNetFactory): num_blocks = (4, 5, 4, 4) transition_features = (160, 224, 256, None) name = "meliusnet22" - def build(self) -> tf.keras.models.Model: - model = super().build() + @property + def imagenet_weights_path(self): + return utils.download_pretrained_model( + model="meliusnet22", + version="v0.1.0", + file="meliusnet22_weights.h5", + file_hash="bb8dda20642508bbe5e0ff95012fec450103c4b23989f4c9c9d853d67b6ff806", + ) - # Load weights. - if self.weights == "imagenet": - # Download appropriate file - if self.include_top: - weights_path = utils.download_pretrained_model( - model="meliusnet22", - version="v0.1.0", - file="meliusnet22_weights.h5", - file_hash="51dba19f17023a9c0b92f1a422ece9b987ce55ab6916cc8a8eb31f6b5e09ba8c", - ) - else: - weights_path = utils.download_pretrained_model( - model="meliusnet22", - version="v0.1.0", - file="meliusnet22_weights_notop.h5", - file_hash="8c029fadb78d28f1d9baa2940844244975e029f1bdc1153d3f63729acc06be3c", - ) - model.load_weights(weights_path) - elif self.weights is not None: - model.load_weights(self.weights) + @property + def imagenet_no_top_weights_path(self): + return utils.download_pretrained_model( + model="meliusnet22", + version="v0.1.0", + file="meliusnet22_weights_notop.h5", + file_hash="9ca867806bff0c2995ff5f1ad085d1627c8dabf12bffbf0bea86eb39ab3cf724", + ) - return model + +######################### +# Functional interfaces # +######################### def MeliusNet22( @@ -193,6 +204,13 @@ def MeliusNet22( ```plot-altair /plots/meliusnet22.vg.json ``` + ```summary + literature.MeliusNet22 + ``` + # ImageNet Metrics + | Top-1 Accuracy | Top-5 Accuracy | Parameters | Memory | + | -------------- | -------------- | ---------- | -------- | + | 62.4 % | 83.9 % | 6 944 584 | 3.88 MiB | # Arguments input_shape: Optional shape tuple, to be specified if you would like to use a model @@ -213,9 +231,7 @@ def MeliusNet22( ValueError: in case of invalid argument for `weights`, or invalid input shape. # References - - [Bi-Real Net: Enhancing the Performance of 1-bit CNNs With Improved - Representational Capability and Advanced Training - Algorithm](https://arxiv.org/abs/1808.00278) + - [MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?](https://arxiv.org/abs/2001.05936) """ return MeliusNet22Factory( include_top=include_top,