
apply_gradients AttributeError: 'ResourceVariable' object has no attribute 'overwrite_with_gradient' #20517

Closed
andrewl36 opened this issue Nov 19, 2024 · 5 comments · Fixed by #20534

@andrewl36

When I have a mix of tf.Variable and KerasVariable objects, I get the following error:

--> 632     if v.overwrite_with_gradient:
    633         if self.gradient_accumulation_steps:
    634             # Utilize a stateless manner for JAX compatibility
    635             steps = self.gradient_accumulation_steps

AttributeError: 'ResourceVariable' object has no attribute 'overwrite_with_gradient'

I suspect this is because my list of variables is [KerasVariables] + [tf.Variables],
and the following line only checks the first variable in the list to decide whether overwrite_with_gradient can be used:

if not hasattr(vars[0], "overwrite_with_gradient"):
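
A minimal sketch of that hypothesis (a hypothetical standalone script, not the actual Keras optimizer source): with a mixed list, a hasattr check on the first element passes, but iterating the whole list still reaches the plain tf.Variable.

import tensorflow as tf
import keras

keras_var = keras.Variable(0.0)  # Keras 3 variable; exposes overwrite_with_gradient
tf_var = tf.Variable(0.0)        # plain ResourceVariable; does not

variables = [keras_var, tf_var]

# The guard passes because the first element is a KerasVariable...
if hasattr(variables[0], "overwrite_with_gradient"):
    for v in variables:
        # ...but the loop then hits the tf.Variable and raises
        # AttributeError, matching the reported traceback.
        print(v.overwrite_with_gradient)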

@sachinprasadhs
Collaborator

Could you please provide a sample reproducible script to replicate the reported behavior? Thanks!

@andrewl36
Author

> Could you please provide a sample reproducible script to replicate the reported behavior? Thanks!

@sachinprasadhs Sure, please try this cut-down example showing the problem:

import numpy as np
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        # Keras model Layers
        self.hidden_layers = [tf.keras.layers.Dense(32, activation='tanh') for _ in range(2)]
        self.output_layer = tf.keras.layers.Dense(1)
        
        # Custom variable
        self.my_var = tf.Variable(0.1, trainable=True, dtype=tf.float32, name="my_var")

    def call(self, inputs):
        x = inputs
        for layer in self.hidden_layers:
            x = layer(x)
        return self.output_layer(x)
    
data = np.array([
    [0.0,    10.4],
    [900.0,  21.1],
    [3900.0, 64.2],
])

model   = MyModel()
inputs  = data[:, 0:1]
outputs = data[:, 1:]

epochs = 1000
learning_rate = 0.005
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

for epoch in range(epochs):
    with tf.GradientTape() as tp:
        y_pred = model(inputs)
        loss   = tf.reduce_mean(tf.square((outputs - y_pred)))
    
    # Mixed list of the layers' Keras variables and the raw tf.Variable
    params = model.trainable_variables + [model.my_var]
        
    gradients = tp.gradient(loss, params)
    optimizer.apply_gradients(zip(gradients, params))
    del tp

@james77777778
Contributor

james77777778 commented Nov 22, 2024

@andrewl36
Try this one:

import numpy as np
import tensorflow as tf

import keras


class MyModel(keras.Model):
    def __init__(self):
        super().__init__()

        # Keras model Layers
        self.hidden_layers = [
            keras.layers.Dense(32, activation="tanh") for _ in range(2)
        ]
        self.output_layer = keras.layers.Dense(1)

        # Custom variable
        self.my_var = self.add_weight(shape=(), dtype="float32", name="my_var")
        self.my_var.assign(0.1)

    def call(self, inputs):
        x = inputs
        for layer in self.hidden_layers:
            x = layer(x)
        return self.output_layer(x)


data = np.array(
    [
        [0.0, 10.4],
        [900.0, 21.1],
        [3900.0, 64.2],
    ]
)

model = MyModel()
inputs = data[:, 0:1]
outputs = data[:, 1:]
epochs = 1000
learning_rate = 0.005
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

pbar = keras.utils.Progbar(epochs)
for epoch in range(epochs):
    with tf.GradientTape() as tp:
        y_pred = model(inputs)
        loss = tf.reduce_mean(tf.square((outputs - y_pred)))
    gradients = tp.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    values = {"loss": loss.numpy()}
    pbar.add(1, values.items())

The key is to use self.add_weight for creating custom variables.
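
A quick way to see the difference (a standalone sketch, assuming Keras 3 with the TensorFlow backend): a variable created through add_weight exposes overwrite_with_gradient, while a raw tf.Variable does not, which is exactly the attribute the optimizer trips over.

import tensorflow as tf
import keras

layer = keras.layers.Layer()
keras_var = layer.add_weight(shape=(), dtype="float32", name="keras_var")
raw_var = tf.Variable(0.1, name="raw_var")

# Expected: True for the add_weight variable, False for the raw tf.Variable
print(hasattr(keras_var, "overwrite_with_gradient"))
print(hasattr(raw_var, "overwrite_with_gradient"))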

EDITED:
I have submitted a PR to fix this.

@andrewl36
Author

@james77777778 thank you, yes that does work now, cheers
