Skip to content

Commit

Permalink
fix layer bug and point to updated pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
koenhelwegen committed Sep 8, 2019
1 parent 316f387 commit abae31b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions larq_zoo/birealnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,17 @@ def residual_block(x, double_filters=False, filters=None):
out = tf.keras.layers.BatchNormalization(momentum=0.8)(out)
out = tf.keras.layers.MaxPool2D(3, strides=2, padding="same")(out)

# layer 2 - 5
# layer 2
out = residual_block(out, filters=args.filters)
for _ in range(1, 5):

# layer 3 - 5
for _ in range(3):
out = residual_block(out)

# layer 6 - 17
for i in range(1, 4):
for _ in range(3):
out = residual_block(out, double_filters=True)
for _ in range(1, 4):
for _ in range(3):
out = residual_block(out)

# layer 18
Expand Down Expand Up @@ -153,16 +155,16 @@ def BiRealNet(
if include_top:
weights_path = utils.download_pretrained_model(
model="birealnet",
version="v0.2.0",
version="v0.3.0",
file="birealnet_weights.h5",
file_hash="e8b29d6204663997dded5629804c0c2e309ec422512a54a17d98802fb39415ec",
file_hash="6e6efac1584fcd60dd024198c87f42eb53b5ec719a5ca1f527e1fe7e8b997117",
)
else:
weights_path = utils.download_pretrained_model(
model="birealnet",
version="v0.2.0",
version="v0.3.0",
file="birealnet_weights_notop.h5",
file_hash="746ff2d2d2b794226e66f0fa3fd0ff19db836df5a9ea9a0f7e59a724e1364757",
file_hash="5148b61c0c2a1094bdef811f68bf4957d5ba5f83ad26437b7a4a6855441ab46b",
)
model.load_weights(weights_path)
elif weights is not None:
Expand Down

0 comments on commit abae31b

Please sign in to comment.