Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
koenhelwegen committed Apr 21, 2020
1 parent b9fa71b commit eefd6d4
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions larq_zoo/literature/meliusnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import tensorflow as tf
from zookeeper import Field, factory

from core import utils
from core.model_factory import ModelFactory
from larq_zoo.core import utils
from larq_zoo.core.model_factory import ModelFactory


@factory
Expand Down Expand Up @@ -117,7 +117,7 @@ def block(self, x):
return self.improvement_block(x)

def build(self) -> tf.keras.models.Model:
x = self.input_tensor
x = self.image_input
x = self.group_stem(x)
for i, (n, f) in enumerate(zip(self.num_blocks, self.transition_features)):
for j in range(n):
Expand All @@ -127,24 +127,18 @@ def build(self) -> tf.keras.models.Model:

x = self.norm(x)
x = self.act(x)
x = utils.global_pool(x)
x = tf.keras.layers.Dense(
self.num_classes, kernel_initializer=self.kernel_initializer
)(x)
x = tf.keras.layers.Activation("softmax", dtype="float32")(x)

if self.include_top:
x = utils.global_pool(x)
x = tf.keras.layers.Dense(
self.num_classes, kernel_initializer=self.kernel_initializer
)(x)
x = tf.keras.layers.Activation("softmax", dtype="float32")(x)

model = tf.keras.models.Model(
inputs=self.input_tensor, outputs=x, name=self.name
inputs=self.image_input, outputs=x, name=self.name
)

if self.weights:
weights = self.weights

if weights.startswith("gs://"):
model.load_weights(utils.get_gcp_weights(weights))
else:
model.load_weights(weights)

return model


Expand All @@ -154,7 +148,7 @@ class MeliusNet22Factory(MeliusNetFactory):
transition_features = (160, 224, 256, None)
name = "meliusnet22"

def build(self):
def build(self) -> tf.keras.models.Model:
model = super().build()

# Load weights.
Expand All @@ -165,18 +159,19 @@ def build(self):
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights.h5",
file_hash="", # TODO
file_hash="51dba19f17023a9c0b92f1a422ece9b987ce55ab6916cc8a8eb31f6b5e09ba8c",
)
else:
weights_path = utils.download_pretrained_model(
model="meliusnet22",
version="v0.1.0",
file="meliusnet22_weights_notop.h5",
file_hash="", # TODO
file_hash="8c029fadb78d28f1d9baa2940844244975e029f1bdc1153d3f63729acc06be3c",
)
model.load_weights(weights_path)
elif self.weights is not None:
model.load_weights(self.weights)

return model


Expand Down

0 comments on commit eefd6d4

Please sign in to comment.