Don't set quantizers as fields (#149)
* Don't set quantizers as fields

* Move quantizer defaults to model factory

* Fix kernel regularizers
lgeiger authored Apr 28, 2020
1 parent 2afdaa3 commit 90b7ed2
Showing 11 changed files with 110 additions and 71 deletions.
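The change is the same across all eleven files: quantizer, constraint, and regularizer defaults stop being zookeeper `Field`s and become either plain class attributes (for string shortcuts and `None`) or read-only properties (for defaults that must construct an object). A minimal sketch of the pattern, with a hypothetical `MyNetFactory` that is not part of this diff; only the larq and zookeeper APIs already visible here are assumed:

    import larq as lq

    from larq_zoo.core.model_factory import ModelFactory


    class MyNetFactory(ModelFactory):  # hypothetical example, not in the diff
        # Before: input_quantizer = Field("ste_sign")
        # After: a string shortcut is just a class attribute ...
        input_quantizer = "ste_sign"
        kernel_constraint = "weight_clip"

        # ... and an object-valued default becomes a property, so a fresh
        # instance is built on every access instead of one Field default.
        @property
        def kernel_quantizer(self):
            return lq.quantizers.SteSign(clip_value=1.25)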
13 changes: 4 additions & 9 deletions larq_zoo/core/model_factory.py
@@ -1,25 +1,20 @@
-from typing import Callable, Optional, Tuple, Union
+from typing import Optional, Tuple

 import tensorflow as tf
 from zookeeper import ComponentField, Field
 from zookeeper.tf import Dataset

 from larq_zoo.core import utils

-QuantizerType = Union[
-    tf.keras.layers.Layer, Callable[[tf.Tensor], tf.Tensor], str, None
-]
-ConstraintType = Union[tf.keras.constraints.Constraint, str, None]
 DimType = Optional[int]


 class ModelFactory:
     """A base class for Larq Zoo models. Defines some common fields."""

-    # Don't set any defaults here.
-    input_quantizer: QuantizerType = Field()
-    kernel_quantizer: QuantizerType = Field()
-    kernel_constraint: ConstraintType = Field()
+    input_quantizer = None
+    kernel_quantizer = None
+    kernel_constraint = None

     # This field is included for automatic inference of `num_classes`, if no
     # value is otherwise provided. We set `allow_missing` because we don't want
6 changes: 3 additions & 3 deletions larq_zoo/literature/binary_alex_net.py
@@ -17,9 +17,9 @@ class BinaryAlexNetFactory(ModelFactory):

     inflation_ratio: int = Field(1)

-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("ste_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "ste_sign"
+    kernel_constraint = "weight_clip"

     def conv_block(
         self,
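For context, string shortcuts such as "ste_sign" and "weight_clip" are resolved by larq itself when the factory hands them to a quantized layer. A hedged sketch of that wiring; the actual larq-zoo build code is not shown in this diff:

    import larq as lq

    # Illustrative only: larq resolves the string shortcuts to quantizer and
    # constraint objects when the layer is constructed.
    conv = lq.layers.QuantConv2D(
        filters=64,
        kernel_size=3,
        padding="same",
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
    )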
6 changes: 3 additions & 3 deletions larq_zoo/literature/birealnet.py
@@ -14,9 +14,9 @@ class BiRealNetFactory(ModelFactory):

     filters: int = Field(64)

-    input_quantizer = Field("approx_sign")
-    kernel_quantizer = Field("magnitude_aware_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "approx_sign"
+    kernel_quantizer = "magnitude_aware_sign"
+    kernel_constraint = "weight_clip"

     kernel_initializer: Union[tf.keras.initializers.Initializer, str] = Field(
         "glorot_normal"
14 changes: 11 additions & 3 deletions larq_zoo/literature/densenet.py
@@ -21,9 +21,17 @@ class BinaryDenseNet(tf.keras.models.Model):
 class BinaryDenseNetFactory(ModelFactory):
     """Implementation of [BinaryDenseNet](https://arxiv.org/abs/1906.08637)"""

-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.3))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.3)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.3)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.3)

     initial_filters: int = Field(64)
     growth_rate: int = Field(64)
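The diff does not say why object-valued defaults become properties rather than plain class attributes; a plausible reason is freshness: a class attribute is evaluated once at import time and shared by everything that reads it, while a property builds a new quantizer on each access. A small sketch of the distinction, assuming nothing beyond standard Python and larq:

    import larq as lq


    class SharedDefault:
        # Evaluated once at class-definition time: every reader of this
        # attribute sees the same SteSign object.
        quantizer = lq.quantizers.SteSign(clip_value=1.3)


    class FreshDefault:
        # Evaluated on every attribute access: each reader gets its own object.
        @property
        def quantizer(self):
            return lq.quantizers.SteSign(clip_value=1.3)


    assert SharedDefault().quantizer is SharedDefault().quantizer
    assert FreshDefault().quantizer is not FreshDefault().quantizer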
14 changes: 9 additions & 5 deletions larq_zoo/literature/dorefanet.py
@@ -38,11 +38,15 @@ class DoReFaNetFactory(ModelFactory):

     activations_k_bit: int = Field(2)

-    input_quantizer = Field(
-        lambda self: lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
-    )
-    kernel_quantizer = Field(lambda: magnitude_aware_sign_unclipped)
-    kernel_constraint = Field(None)
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)
+
+    @property
+    def kernel_quantizer(self):
+        return magnitude_aware_sign_unclipped
+
+    kernel_constraint = None

     def conv_block(
         self, x, filters, kernel_size, strides=1, pool=False, pool_padding="same"
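DoReFaNet is the one case where a default depends on other configuration: the old `Field(lambda self: ...)` form existed only to reach `self.activations_k_bit`, which a property does naturally. A stripped-down sketch (plain Python plus larq, no zookeeper machinery; `Example` is hypothetical):

    import larq as lq


    class Example:
        activations_k_bit = 2  # stands in for the zookeeper Field above

        @property
        def input_quantizer(self):
            # The property can read sibling configuration at access time.
            return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit)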
15 changes: 12 additions & 3 deletions larq_zoo/literature/meliusnet.py
@@ -26,9 +26,18 @@ class MeliusNetFactory(ModelFactory):
     kernel_initializer: Optional[Union[str, tf.keras.initializers.Initializer]] = Field(
         "glorot_normal"
     )
-    input_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(1.3))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(1.3))
+
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(1.3)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(1.3)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(1.3)

     def pool(self, x: tf.Tensor, name: str = None) -> tf.Tensor:
         return tf.keras.layers.MaxPool2D(2, strides=2, padding="same", name=name)(x)
67 changes: 41 additions & 26 deletions larq_zoo/literature/real_to_bin_nets.py
@@ -11,7 +11,7 @@
 from zookeeper import Field, factory

 from larq_zoo.core import utils
-from larq_zoo.core.model_factory import ModelFactory, QuantizerType
+from larq_zoo.core.model_factory import ModelFactory


 class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta):
@@ -21,7 +21,7 @@ class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta):
     model_name: str = Field()
     momentum: float = Field(0.99)
     kernel_initializer: str = Field("glorot_normal")
-    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field(None)
+    kernel_regularizer = None

     def first_block(
         self, x: tf.Tensor, use_prelu: bool = True, name: str = ""
@@ -126,8 +126,8 @@ class StrongBaselineNetFactory(_SharedBaseFactory):

     scaling_r: int = 8

-    input_quantizer: QuantizerType = Field(None)
-    kernel_quantizer: QuantizerType = Field(None)
+    input_quantizer = None
+    kernel_quantizer = None

     class LearnedRescaleLayer(tf.keras.layers.Layer):
         """Implements the learned activation rescaling XNOR-Net++ style.
@@ -359,53 +359,68 @@ def block(
 @factory
 class StrongBaselineNetBANFactory(StrongBaselineNetFactory):
     model_name = Field("baseline_ban")
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field(None)
-    kernel_constraint = Field(None)
-    kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
+    input_quantizer = "ste_sign"
+    kernel_quantizer = None
+    kernel_constraint = None
+
+    @property
+    def kernel_regularizer(self):
+        return tf.keras.regularizers.l2(1e-5)


 @factory
 class StrongBaselineNetBNNFactory(StrongBaselineNetFactory):
     model_name = Field("baseline_bnn")
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("ste_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "ste_sign"
+    kernel_constraint = "weight_clip"


 @factory
 class RealToBinNetFPFactory(RealToBinNetFactory):
     model_name = Field("r2b_fp")
-    input_quantizer = Field(lambda: tf.keras.layers.Activation("tanh"))
-    kernel_quantizer = Field(None)
-    kernel_constraint = Field(None)
-    kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
+    kernel_quantizer = None
+    kernel_constraint = None
+
+    @property
+    def input_quantizer(self):
+        return tf.keras.layers.Activation("tanh")
+
+    @property
+    def kernel_regularizer(self):
+        return tf.keras.regularizers.l2(1e-5)


 @factory
 class RealToBinNetBANFactory(RealToBinNetFactory):
     model_name = Field("r2b_ban")
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field(None)
-    kernel_constraint = Field(None)
-    kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
+    input_quantizer = "ste_sign"
+    kernel_quantizer = None
+    kernel_constraint = None
+
+    @property
+    def kernel_regularizer(self):
+        return tf.keras.regularizers.l2(1e-5)


 @factory
 class RealToBinNetBNNFactory(RealToBinNetFactory):
     model_name = Field("r2b_bnn")
-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("ste_sign")
-    kernel_constraint = Field("weight_clip")
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "ste_sign"
+    kernel_constraint = "weight_clip"


 @factory
 class ResNet18FPFactory(ResNet18Factory):
     model_name = Field("resnet_fp")
-    input_quantizer = Field(None)
-    kernel_quantizer = Field(None)
-    kernel_constraint = Field(None)
-    kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5))
+    input_quantizer = None
+    kernel_quantizer = None
+    kernel_constraint = None
+
+    @property
+    def kernel_regularizer(self):
+        return tf.keras.regularizers.l2(1e-5)


 def RealToBinaryNet(
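The "Fix kernel regularizers" bullet lands here: `kernel_regularizer` moves from a `Field(lambda: tf.keras.regularizers.l2(1e-5))` default to a property, so every access yields a distinct regularizer object. A sketch of the resulting behaviour, assuming only standard TensorFlow:

    import tensorflow as tf


    def make_regularizer():
        # Mirrors the new @property form: build a fresh object per use.
        return tf.keras.regularizers.l2(1e-5)


    layer_a = tf.keras.layers.Dense(128, kernel_regularizer=make_regularizer())
    layer_b = tf.keras.layers.Dense(128, kernel_regularizer=make_regularizer())
    assert layer_a.kernel_regularizer is not layer_b.kernel_regularizer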
14 changes: 11 additions & 3 deletions larq_zoo/literature/resnet_e.py
@@ -15,9 +15,17 @@ class BinaryResNetE18Factory(ModelFactory):
     num_layers: int = Field(18)
     initial_filters: int = Field(64)

-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.25)

     @property
     def spec(self):
12 changes: 4 additions & 8 deletions larq_zoo/literature/xnornet.py
@@ -2,7 +2,7 @@

 import larq as lq
 import tensorflow as tf
-from zookeeper import Field, factory
+from zookeeper import factory

 from larq_zoo.core import utils
 from larq_zoo.core.model_factory import ModelFactory
@@ -24,13 +24,9 @@ def xnor_weight_scale(x):
 class XNORNetFactory(ModelFactory):
     """Implementation of [XNOR-Net](https://arxiv.org/abs/1603.05279)"""

-    input_quantizer = Field("ste_sign")
-    kernel_quantizer = Field("xnor_weight_scale")
-    kernel_constraint = Field("weight_clip")
-
-    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field(
-        lambda: tf.keras.regularizers.l2(5e-7)
-    )
+    input_quantizer = "ste_sign"
+    kernel_quantizer = "xnor_weight_scale"
+    kernel_constraint = "weight_clip"

     @property
     def kernel_regularizer(self):
14 changes: 11 additions & 3 deletions larq_zoo/sota/quicknet.py
@@ -65,9 +65,17 @@ class QuickNetBaseFactory(ModelFactory, abc.ABC):
     transition_block: Callable[..., tf.Tensor] = Field()
     stem_filters: int = Field(64)

-    input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
-    kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25))
+    @property
+    def input_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_quantizer(self):
+        return lq.quantizers.SteSign(clip_value=1.25)
+
+    @property
+    def kernel_constraint(self):
+        return lq.constraints.WeightClip(clip_value=1.25)

     def __post_configure__(self):
         assert (
@@ -4,7 +4,7 @@
 import tensorflow as tf
 from zookeeper import ComponentField, Field, factory

-from larq_zoo.core.model_factory import ConstraintType, ModelFactory, QuantizerType
+from larq_zoo.core.model_factory import ModelFactory


 class AttentionMatchingLossLayer(tf.keras.layers.Layer):
@@ -296,10 +296,6 @@ class TeacherStudentModelFactory(ModelFactory):
     teacher_model: tf.keras.models.Model = ComponentField(allow_missing=True)
     student_model: tf.keras.models.Model = ComponentField()

-    input_quantizer: QuantizerType = Field(allow_missing=True)
-    kernel_quantizer: QuantizerType = Field(allow_missing=True)
-    kernel_constraint: ConstraintType = Field(allow_missing=True)
-
     # Must be set if there is a teacher and allow_missing teacher weights is not True.
     # Either a full path or the name of a network (in which case it will be sought in the current `model_dir`).
     initialize_teacher_weights_from: str = Field(allow_missing=True)
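With `None` defaults now on `ModelFactory` itself (first file in this diff), `TeacherStudentModelFactory` no longer needs its own `allow_missing` fields; a factory that never quantizes simply inherits `None`. A minimal sketch, with `PlainFactory` as a hypothetical subclass:

    from larq_zoo.core.model_factory import ModelFactory


    class PlainFactory(ModelFactory):  # hypothetical, for illustration
        pass


    assert PlainFactory.input_quantizer is None
    assert PlainFactory.kernel_quantizer is None
    assert PlainFactory.kernel_constraint is None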
