Skip to content

Commit

Permalink
correct hashes, docstr etc
Browse files Browse the repository at this point in the history
  • Loading branch information
koenhelwegen committed Apr 21, 2020
1 parent eefd6d4 commit 169efe2
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions larq_zoo/literature/meliusnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,40 +139,51 @@ def build(self) -> tf.keras.models.Model:
inputs=self.image_input, outputs=x, name=self.name
)

if self.weights == "imagenet":
if self.include_top:
weights_path = self.imagenet_weights_path
else:
weights_path = self.imagenet_no_top_weights_path
model.load_weights(weights_path)
elif self.weights is not None:
model.load_weights(self.weights)

return model


######################
# Concrete factories #
######################


@factory
class MeliusNet22Factory(MeliusNetFactory):
num_blocks = (4, 5, 4, 4)
transition_features = (160, 224, 256, None)
name = "meliusnet22"

def build(self) -> tf.keras.models.Model:
model = super().build()
@property
def imagenet_weights_path(self):
return utils.download_pretrained_model(
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights.h5",
file_hash="bb8dda20642508bbe5e0ff95012fec450103c4b23989f4c9c9d853d67b6ff806",
)

# Load weights.
if self.weights == "imagenet":
# Download appropriate file
if self.include_top:
weights_path = utils.download_pretrained_model(
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights.h5",
file_hash="51dba19f17023a9c0b92f1a422ece9b987ce55ab6916cc8a8eb31f6b5e09ba8c",
)
else:
weights_path = utils.download_pretrained_model(
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights_notop.h5",
file_hash="8c029fadb78d28f1d9baa2940844244975e029f1bdc1153d3f63729acc06be3c",
)
model.load_weights(weights_path)
elif self.weights is not None:
model.load_weights(self.weights)
@property
def imagenet_no_top_weights_path(self):
return utils.download_pretrained_model(
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights_notop.h5",
file_hash="9ca867806bff0c2995ff5f1ad085d1627c8dabf12bffbf0bea86eb39ab3cf724",
)

return model

#########################
# Functional interfaces #
#########################


def MeliusNet22(
Expand All @@ -193,6 +204,13 @@ def MeliusNet22(
```plot-altair
/plots/meliusnet22.vg.json
```
```summary
literature.MeliusNet22
```
# ImageNet Metrics
| Top-1 Accuracy | Top-5 Accuracy | Parameters | Memory |
| -------------- | -------------- | ---------- | -------- |
| 62.4 % | 83.9 % | 6 944 584 | 3.88 MiB |
# Arguments
input_shape: Optional shape tuple, to be specified if you would like to use a model
Expand All @@ -213,9 +231,7 @@ def MeliusNet22(
ValueError: in case of invalid argument for `weights`, or invalid input shape.
# References
- [Bi-Real Net: Enhancing the Performance of 1-bit CNNs With Improved
Representational Capability and Advanced Training
Algorithm](https://arxiv.org/abs/1808.00278)
- [MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?](https://arxiv.org/abs/2001.05936)
"""
return MeliusNet22Factory(
include_top=include_top,
Expand Down

0 comments on commit 169efe2

Please sign in to comment.