From 90b7ed2d54c2417bfce28b197615a555e12891b8 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 28 Apr 2020 14:04:40 +0100 Subject: [PATCH] Don't set quantizers as fields (#149) * Don't set quantizers as fields * Move quantizer defaults to model factory * Fix kernel regularizers --- larq_zoo/core/model_factory.py | 13 ++-- larq_zoo/literature/binary_alex_net.py | 6 +- larq_zoo/literature/birealnet.py | 6 +- larq_zoo/literature/densenet.py | 14 +++- larq_zoo/literature/dorefanet.py | 14 ++-- larq_zoo/literature/meliusnet.py | 15 ++++- larq_zoo/literature/real_to_bin_nets.py | 67 ++++++++++++------- larq_zoo/literature/resnet_e.py | 14 +++- larq_zoo/literature/xnornet.py | 12 ++-- larq_zoo/sota/quicknet.py | 14 +++- .../knowledge_distillation.py | 6 +- 11 files changed, 110 insertions(+), 71 deletions(-) diff --git a/larq_zoo/core/model_factory.py b/larq_zoo/core/model_factory.py index ca334a48..79f7c470 100644 --- a/larq_zoo/core/model_factory.py +++ b/larq_zoo/core/model_factory.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple import tensorflow as tf from zookeeper import ComponentField, Field @@ -6,20 +6,15 @@ from larq_zoo.core import utils -QuantizerType = Union[ - tf.keras.layers.Layer, Callable[[tf.Tensor], tf.Tensor], str, None -] -ConstraintType = Union[tf.keras.constraints.Constraint, str, None] DimType = Optional[int] class ModelFactory: """A base class for Larq Zoo models. Defines some common fields.""" - # Don't set any defaults here. - input_quantizer: QuantizerType = Field() - kernel_quantizer: QuantizerType = Field() - kernel_constraint: ConstraintType = Field() + input_quantizer = None + kernel_quantizer = None + kernel_constraint = None # This field is included for automatic inference of `num_clases`, if no # value is otherwise provided. We set `allow_missing` because we don't want diff --git a/larq_zoo/literature/binary_alex_net.py b/larq_zoo/literature/binary_alex_net.py index 96ee57b4..3cf6a328 100644 --- a/larq_zoo/literature/binary_alex_net.py +++ b/larq_zoo/literature/binary_alex_net.py @@ -17,9 +17,9 @@ class BinaryAlexNetFactory(ModelFactory): inflation_ratio: int = Field(1) - input_quantizer = Field("ste_sign") - kernel_quantizer = Field("ste_sign") - kernel_constraint = Field("weight_clip") + input_quantizer = "ste_sign" + kernel_quantizer = "ste_sign" + kernel_constraint = "weight_clip" def conv_block( self, diff --git a/larq_zoo/literature/birealnet.py b/larq_zoo/literature/birealnet.py index cffbd6e6..aebea155 100644 --- a/larq_zoo/literature/birealnet.py +++ b/larq_zoo/literature/birealnet.py @@ -14,9 +14,9 @@ class BiRealNetFactory(ModelFactory): filters: int = Field(64) - input_quantizer = Field("approx_sign") - kernel_quantizer = Field("magnitude_aware_sign") - kernel_constraint = Field("weight_clip") + input_quantizer = "approx_sign" + kernel_quantizer = "magnitude_aware_sign" + kernel_constraint = "weight_clip" kernel_initializer: Union[tf.keras.initializers.Initializer, str] = Field( "glorot_normal" diff --git a/larq_zoo/literature/densenet.py b/larq_zoo/literature/densenet.py index 49dc266a..3e82d5cb 100644 --- a/larq_zoo/literature/densenet.py +++ b/larq_zoo/literature/densenet.py @@ -21,9 +21,17 @@ class BinaryDenseNet(tf.keras.models.Model): class BinaryDenseNetFactory(ModelFactory): """Implementation of [BinaryDenseNet](https://arxiv.org/abs/1906.08637)""" - input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3)) - kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.3)) - kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.3)) + @property + def input_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.3) + + @property + def kernel_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.3) + + @property + def kernel_constraint(self): + return lq.constraints.WeightClip(clip_value=1.3) initial_filters: int = Field(64) growth_rate: int = Field(64) diff --git a/larq_zoo/literature/dorefanet.py b/larq_zoo/literature/dorefanet.py index cfe204d9..8463b75a 100644 --- a/larq_zoo/literature/dorefanet.py +++ b/larq_zoo/literature/dorefanet.py @@ -38,11 +38,15 @@ class DoReFaNetFactory(ModelFactory): activations_k_bit: int = Field(2) - input_quantizer = Field( - lambda self: lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit) - ) - kernel_quantizer = Field(lambda: magnitude_aware_sign_unclipped) - kernel_constraint = Field(None) + @property + def input_quantizer(self): + return lq.quantizers.DoReFaQuantizer(k_bit=self.activations_k_bit) + + @property + def kernel_quantizer(self): + return magnitude_aware_sign_unclipped + + kernel_constraint = None def conv_block( self, x, filters, kernel_size, strides=1, pool=False, pool_padding="same" diff --git a/larq_zoo/literature/meliusnet.py b/larq_zoo/literature/meliusnet.py index dd691965..6b6151cb 100644 --- a/larq_zoo/literature/meliusnet.py +++ b/larq_zoo/literature/meliusnet.py @@ -26,9 +26,18 @@ class MeliusNetFactory(ModelFactory): kernel_initializer: Optional[Union[str, tf.keras.initializers.Initializer]] = Field( "glorot_normal" ) - input_quantizer = Field(lambda: lq.quantizers.SteSign(1.3)) - kernel_quantizer = Field(lambda: lq.quantizers.SteSign(1.3)) - kernel_constraint = Field(lambda: lq.constraints.WeightClip(1.3)) + + @property + def input_quantizer(self): + return lq.quantizers.SteSign(1.3) + + @property + def kernel_quantizer(self): + return lq.quantizers.SteSign(1.3) + + @property + def kernel_constraint(self): + return lq.constraints.WeightClip(1.3) def pool(self, x: tf.Tensor, name: str = None) -> tf.Tensor: return tf.keras.layers.MaxPool2D(2, strides=2, padding="same", name=name)(x) diff --git a/larq_zoo/literature/real_to_bin_nets.py b/larq_zoo/literature/real_to_bin_nets.py index be1e6b7d..d41869ca 100644 --- a/larq_zoo/literature/real_to_bin_nets.py +++ b/larq_zoo/literature/real_to_bin_nets.py @@ -11,7 +11,7 @@ from zookeeper import Field, factory from larq_zoo.core import utils -from larq_zoo.core.model_factory import ModelFactory, QuantizerType +from larq_zoo.core.model_factory import ModelFactory class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta): @@ -21,7 +21,7 @@ class _SharedBaseFactory(ModelFactory, metaclass=ABCMeta): model_name: str = Field() momentum: float = Field(0.99) kernel_initializer: str = Field("glorot_normal") - kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field(None) + kernel_regularizer = None def first_block( self, x: tf.Tensor, use_prelu: bool = True, name: str = "" @@ -126,8 +126,8 @@ class StrongBaselineNetFactory(_SharedBaseFactory): scaling_r: int = 8 - input_quantizer: QuantizerType = Field(None) - kernel_quantizer: QuantizerType = Field(None) + input_quantizer = None + kernel_quantizer = None class LearnedRescaleLayer(tf.keras.layers.Layer): """Implements the learned activation rescaling XNOR-Net++ style. @@ -359,53 +359,68 @@ def block( @factory class StrongBaselineNetBANFactory(StrongBaselineNetFactory): model_name = Field("baseline_ban") - input_quantizer = Field("ste_sign") - kernel_quantizer = Field(None) - kernel_constraint = Field(None) - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) + input_quantizer = "ste_sign" + kernel_quantizer = None + kernel_constraint = None + + @property + def kernel_regularizer(self): + return tf.keras.regularizers.l2(1e-5) @factory class StrongBaselineNetBNNFactory(StrongBaselineNetFactory): model_name = Field("baseline_bnn") - input_quantizer = Field("ste_sign") - kernel_quantizer = Field("ste_sign") - kernel_constraint = Field("weight_clip") + input_quantizer = "ste_sign" + kernel_quantizer = "ste_sign" + kernel_constraint = "weight_clip" @factory class RealToBinNetFPFactory(RealToBinNetFactory): model_name = Field("r2b_fp") - input_quantizer = Field(lambda: tf.keras.layers.Activation("tanh")) - kernel_quantizer = Field(None) - kernel_constraint = Field(None) - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) + kernel_quantizer = None + kernel_constraint = None + + @property + def input_quantizer(self): + return tf.keras.layers.Activation("tanh") + + @property + def kernel_regularizer(self): + return tf.keras.regularizers.l2(1e-5) @factory class RealToBinNetBANFactory(RealToBinNetFactory): model_name = Field("r2b_ban") - input_quantizer = Field("ste_sign") - kernel_quantizer = Field(None) - kernel_constraint = Field(None) - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) + input_quantizer = "ste_sign" + kernel_quantizer = None + kernel_constraint = None + + @property + def kernel_regularizer(self): + return tf.keras.regularizers.l2(1e-5) @factory class RealToBinNetBNNFactory(RealToBinNetFactory): model_name = Field("r2b_bnn") - input_quantizer = Field("ste_sign") - kernel_quantizer = Field("ste_sign") - kernel_constraint = Field("weight_clip") + input_quantizer = "ste_sign" + kernel_quantizer = "ste_sign" + kernel_constraint = "weight_clip" @factory class ResNet18FPFactory(ResNet18Factory): model_name = Field("resnet_fp") - input_quantizer = Field(None) - kernel_quantizer = Field(None) - kernel_constraint = Field(None) - kernel_regularizer = Field(lambda: tf.keras.regularizers.l2(1e-5)) + input_quantizer = None + kernel_quantizer = None + kernel_constraint = None + + @property + def kernel_regularizer(self): + return tf.keras.regularizers.l2(1e-5) def RealToBinaryNet( diff --git a/larq_zoo/literature/resnet_e.py b/larq_zoo/literature/resnet_e.py index 93b82c37..974e0ce8 100644 --- a/larq_zoo/literature/resnet_e.py +++ b/larq_zoo/literature/resnet_e.py @@ -15,9 +15,17 @@ class BinaryResNetE18Factory(ModelFactory): num_layers: int = Field(18) initial_filters: int = Field(64) - input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25)) - kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25)) - kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25)) + @property + def input_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.25) + + @property + def kernel_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.25) + + @property + def kernel_constraint(self): + return lq.constraints.WeightClip(clip_value=1.25) @property def spec(self): diff --git a/larq_zoo/literature/xnornet.py b/larq_zoo/literature/xnornet.py index 27381a5c..76cc0467 100644 --- a/larq_zoo/literature/xnornet.py +++ b/larq_zoo/literature/xnornet.py @@ -2,7 +2,7 @@ import larq as lq import tensorflow as tf -from zookeeper import Field, factory +from zookeeper import factory from larq_zoo.core import utils from larq_zoo.core.model_factory import ModelFactory @@ -24,13 +24,9 @@ def xnor_weight_scale(x): class XNORNetFactory(ModelFactory): """Implementation of [XNOR-Net](https://arxiv.org/abs/1603.05279)""" - input_quantizer = Field("ste_sign") - kernel_quantizer = Field("xnor_weight_scale") - kernel_constraint = Field("weight_clip") - - kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = Field( - lambda: tf.keras.regularizers.l2(5e-7) - ) + input_quantizer = "ste_sign" + kernel_quantizer = "xnor_weight_scale" + kernel_constraint = "weight_clip" @property def kernel_regularizer(self): diff --git a/larq_zoo/sota/quicknet.py b/larq_zoo/sota/quicknet.py index 6edaca29..d7830a56 100644 --- a/larq_zoo/sota/quicknet.py +++ b/larq_zoo/sota/quicknet.py @@ -65,9 +65,17 @@ class QuickNetBaseFactory(ModelFactory, abc.ABC): transition_block: Callable[..., tf.Tensor] = Field() stem_filters: int = Field(64) - input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25)) - kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25)) - kernel_constraint = Field(lambda: lq.constraints.WeightClip(clip_value=1.25)) + @property + def input_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.25) + + @property + def kernel_quantizer(self): + return lq.quantizers.SteSign(clip_value=1.25) + + @property + def kernel_constraint(self): + return lq.constraints.WeightClip(clip_value=1.25) def __post_configure__(self): assert ( diff --git a/larq_zoo/training/knowledge_distillation/knowledge_distillation.py b/larq_zoo/training/knowledge_distillation/knowledge_distillation.py index 5a7ac765..9591ef0a 100644 --- a/larq_zoo/training/knowledge_distillation/knowledge_distillation.py +++ b/larq_zoo/training/knowledge_distillation/knowledge_distillation.py @@ -4,7 +4,7 @@ import tensorflow as tf from zookeeper import ComponentField, Field, factory -from larq_zoo.core.model_factory import ConstraintType, ModelFactory, QuantizerType +from larq_zoo.core.model_factory import ModelFactory class AttentionMatchingLossLayer(tf.keras.layers.Layer): @@ -296,10 +296,6 @@ class TeacherStudentModelFactory(ModelFactory): teacher_model: tf.keras.models.Model = ComponentField(allow_missing=True) student_model: tf.keras.models.Model = ComponentField() - input_quantizer: QuantizerType = Field(allow_missing=True) - kernel_quantizer: QuantizerType = Field(allow_missing=True) - kernel_constraint: ConstraintType = Field(allow_missing=True) - # Must be set if there is a teacher and allow_missing teacher weights is not True. # Either a full path or the name of a network (in which case it will be sought in the current `model_dir`). initialize_teacher_weights_from: str = Field(allow_missing=True)