diff --git a/.travis.yml b/.travis.yml index f01246f..1faa374 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,8 +5,6 @@ python: - 3.6 env: - KERAS_BACKEND=tensorflow - - KERAS_BACKEND=tensorflow TF_KERAS=1 - - KERAS_BACKEND=tensorflow TF_KERAS=1 TF_EAGER=1 - KERAS_BACKEND=tensorflow TF_KERAS=1 TF_2=1 # - KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile # - KERAS_BACKEND=cntk PYTHONWARNINGS=ignore diff --git a/README.md b/README.md index 1b7e4b4..3b08385 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,6 @@ [![Build Status](https://travis-ci.com/jernsting/keras-drop-connect.svg?branch=master)](https://travis-ci.com/jernsting/keras-drop-connect) ![](https://img.shields.io/badge/keras-tensorflow-blue.svg) -![](https://img.shields.io/badge/keras-tf.keras-blue.svg) -![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg) ![](https://img.shields.io/badge/keras-tf.keras/2.0.0_beta-blue.svg) diff --git a/keras_drop_connect/wrappers.py b/keras_drop_connect/wrappers.py index 6962986..83a3c9a 100644 --- a/keras_drop_connect/wrappers.py +++ b/keras_drop_connect/wrappers.py @@ -55,17 +55,16 @@ def _dropped_weight(): raise Exception("Unknown name: {}".format(name)) else: pass - for name in names: - name = name.split("/")[1].split(":")[0] # this is not nice, but it works for now with all Versions + for name in dir(self.layer): try: w = getattr(self.layer, name) - origins[name] = w - if 0. < self.rate < 1.: - setattr(self.layer, name, K.in_train_phase(dropped_weight(w, self.rate), w, - training=training)) + if K.is_keras_tensor(w) and w.name in names: + origins[name] = w + if 0. < self.rate < 1.: + setattr(self.layer, name, K.in_train_phase(dropped_weight(w, self.rate), w, + training=training)) except Exception as e: continue - outputs = self.layer.call(inputs, **kwargs) for name, w in origins.items(): setattr(self.layer, name, w)