tfrs.metrics.FactorizedTopK #712

Open
soheil-asgari opened this issue Apr 3, 2024 · 12 comments

Comments

@soheil-asgari

Hi,
when I use this code I get this error:

    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=(
                movies.batch(128)
                .cache()
                .map(lambda title: (title, self.movie_model(title)))
            )
        )
    )

error:

ValueError: Cannot convert '('c', 'o', 'u', 'n', 't', 'e', 'r')' to a shape. Found invalid entry 'c' of type '<class 'str'>'.
Exception ignored in: <function AtomicFunction.__del__ at 0x7d6fc6688d60>

@Lopera47

Lopera47 commented Apr 5, 2024

Hello @soheil-asgari I got the same error and solved it. It is a dependency issue with the latest version of TensorFlow. In my case, in order to solve it, I installed TFRS without dependencies (pip install tensorflow-recommenders --no-deps), because I already had the needed version of tensorflow (2.11.0) and other dependencies.

Here's another similar approach: https://stackoverflow.com/questions/78144515/error-initializing-factorizedtopk-in-tensorflow-recommenders-on-sagemaker-cann

@soheil-asgari
Author

Hello @Lopera47
I also installed TensorFlow 2.11.0, but when I install scann, it automatically upgrades TensorFlow to the latest version.

@Lopera47

Lopera47 commented Apr 5, 2024

@soheil-asgari are you explicitly using ScaNN? In my case I'm not. You could try installing ScaNN without dependencies so that it does not upgrade TensorFlow.

@rlcauvin

rlcauvin commented Apr 6, 2024

This is a bug in the implementation of the tfrs.layers.factorized_top_k module.

In recommenders/tensorflow_recommenders/layers/factorized_top_k.py, we find:

self._counter = self.add_weight("counter", dtype=tf.int32, trainable=False)

Note that the first argument passed to add_weight is the “counter” string.

In the Keras 2.15.0 implementation of add_weight, we find:

def add_weight(
    self,
    name=None,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    constraint=None,
    use_resource=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE,
    **kwargs,
):

Note that the first argument is the name of the weight variable. Passing “counter” positionally therefore sets the name, which works fine in Keras 2.15.0, and we won’t get the ValueError: Cannot convert '('c', 'o', 'u', 'n', 't', 'e', 'r')' to a shape error.

But the Keras 3.1.1 implementation of add_weight expects shape as the first argument:

def add_weight(
    self,
    shape=None,
    initializer=None,
    dtype=None,
    trainable=True,
    regularizer=None,
    constraint=None,
    name=None,
):

So when it goes on to check for a valid shape, the shape is set to “counter”, which is not valid, and it results in ValueError: Cannot convert '('c', 'o', 'u', 'n', 't', 'e', 'r')' to a shape.

The bug, therefore, is in TensorFlow Recommenders’ implementation of the tfrs.layers.factorized_top_k module.
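
The positional-argument mismatch described above can be reproduced without installing TensorFlow or Keras. The sketch below uses two stand-in functions (hypothetical names, with parameter lists abbreviated from the signatures quoted above) and `inspect.signature` to show which parameter receives "counter" under each ordering. The keyword call at the end is one plausible way the TFRS call could be written so that it binds correctly under both orderings; it is an illustration, not a confirmed upstream patch.

```python
import inspect


def add_weight_keras2(name=None, shape=None, dtype=None, initializer=None,
                      trainable=None):
    """Stand-in with the Keras 2.15.0 parameter order (name first)."""


def add_weight_keras3(shape=None, initializer=None, dtype=None,
                      trainable=True, name=None):
    """Stand-in with the Keras 3.1.1 parameter order (shape first)."""


def bound_arguments(func, *args, **kwargs):
    """Return the parameters that the given call would actually bind."""
    return dict(inspect.signature(func).bind(*args, **kwargs).arguments)


# The call as written in factorized_top_k.py: "counter" is positional.
old_call = bound_arguments(add_weight_keras2, "counter", dtype="int32", trainable=False)
new_call = bound_arguments(add_weight_keras3, "counter", dtype="int32", trainable=False)

# Passing the name by keyword binds the same way under both orderings.
fixed_old = bound_arguments(add_weight_keras2, name="counter", dtype="int32", trainable=False)
fixed_new = bound_arguments(add_weight_keras3, name="counter", dtype="int32", trainable=False)

print("Keras 2 ordering:", old_call)
print("Keras 3 ordering:", new_call)
print("Keyword call    :", fixed_new)
```

Under the Keras 3 ordering the string ends up in `shape` and `name` is never set, which is why the failure surfaces in shape validation rather than as a complaint about the name.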

@Chm-vinicius

I can confirm that this is a bug with tensorflow 2.16.0 and tensorflow-recommenders 0.7.3. The error output is:
Cannot convert '('c', 'o', 'u', 'n', 't', 'e', 'r')' to a shape

I was only able to run training by downgrading tensorflow to version 2.15.1, but then I can't save the model with signature_inputs: the kernel gets stuck in an infinite loop. This behavior is referred to in this issue.

@rlcauvin

Seems to me that this issue should be marked as a bug in TensorFlow Recommenders (see my diagnosis of exactly where the bug in the code is) and fixed ASAP. Maybe @soheil-asgari can add the Bug label?

@soheil-asgari
Author

Hello @rlcauvin, unfortunately, I cannot add the Bug label.

@celha

celha commented May 14, 2024

I'm facing the same error, also with tensorflow version 2.15, but only with tensorflow-recommenders version 0.7.2. Any ideas?
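
Since the reports in this thread differ by version combination, it may help to check what is actually installed in the environment, because a pinned tensorflow version does not necessarily pin keras. The following is a minimal sketch using only the standard library; the helper names `installed_versions` and `major_version` are made up for this example, and the package list is an assumption about what is relevant here.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for package in packages:
        try:
            versions[package] = metadata.version(package)
        except metadata.PackageNotFoundError:
            versions[package] = None
    return versions


def major_version(version):
    """Return the leading integer of a version string such as '3.1.1'."""
    return int(version.split(".")[0])


report = installed_versions(
    ["tensorflow", "keras", "tf-keras", "tensorflow-recommenders"]
)
for package, version in report.items():
    print(f"{package}: {version or 'not installed'}")

keras_version = report["keras"]
if keras_version is not None and major_version(keras_version) >= 3:
    # Keras 3.1.1 takes shape as the first add_weight argument (see above);
    # treating every 3.x release the same way is an assumption of this sketch.
    print("Keras 3 detected: the positional 'counter' argument will be read as a shape.")
```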

@rlcauvin

For now, I've worked around the TensorFlow Recommenders bug by including the following code in my notebook before installing any TensorFlow related packages:

import os
os.environ['TF_USE_LEGACY_KERAS'] = '1'

My notebook installs tensorflow-recommenders version 0.7.3 and tensorflow version 2.16.1.

I've been unable to install and import tensorflow-ranking, however, without reverting to a prior version of tensorflow.
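
The ordering matters for the workaround above: the variable has to be set before TensorFlow chooses its Keras implementation, so the safest place is before the first TensorFlow import in the process. The sketch below wraps that check in a hypothetical helper, `enable_legacy_keras`. As far as I know, the legacy path also requires the separate tf-keras package to be installed alongside tensorflow 2.16; treat that as an assumption to verify in your environment.

```python
import os
import sys


def enable_legacy_keras():
    """Ask tf.keras to use Keras 2; return False if TensorFlow is already loaded."""
    if "tensorflow" in sys.modules:
        # TensorFlow may already have chosen its Keras implementation,
        # so setting the variable at this point may be ignored.
        return False
    os.environ["TF_USE_LEGACY_KERAS"] = "1"
    return True


applied = enable_legacy_keras()
print("TF_USE_LEGACY_KERAS set before TensorFlow import:", applied)

# Only after this point should the notebook run, for example:
#   import tensorflow as tf
#   import tensorflow_recommenders as tfrs
```

If the helper returns False, restarting the kernel and running this cell first is the simplest way to make the setting take effect.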

@celha

celha commented May 15, 2024

Alright! Thank you @rlcauvin :)

@y-71

y-71 commented Aug 3, 2024

What worked for me was Python 3.7 with the following requirements.txt:

Tensorflow==2.10.0
tensorflow_recommenders==0.7.2
tensorflow-datasets
scann

@peterimeokparia

peterimeokparia commented Oct 25, 2024

I am still experiencing this problem in the Colab editor and runtime. I am able to define the model, but when I go to set up the task:

    # Define your objectives.
    task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
        movies.batch(128).map(movie_model)
      )
    )

it throws the following exception about the shape not being properly defined in FactorizedTopK.


ValueError                                Traceback (most recent call last)
in <cell line: 13>()
     11
     12 # Define your objectives.
---> 13 task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
     14     movies.batch(128).map(movie_model)
     15   )

5 frames
/usr/local/lib/python3.10/dist-packages/tensorflow_recommenders/metrics/factorized_top_k.py in __init__(self, candidates, ks, name)
     77     if isinstance(candidates, tf.data.Dataset):
     78       candidates = (
---> 79           layers.factorized_top_k.Streaming(k=max(ks))
     80           .index_from_dataset(candidates)
     81       )

/usr/local/lib/python3.10/dist-packages/tensorflow_recommenders/layers/factorized_top_k.py in __init__(self, query_model, k, handle_incomplete_batches, num_parallel_calls, sorted_order)
    374     self._sorted = sorted_order
    375
--> 376     self._counter = self.add_weight("counter", dtype=tf.int32, trainable=False)
    377
    378   def index_from_dataset(

/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py in add_weight(self, shape, initializer, dtype, trainable, autocast, regularizer, constraint, aggregation, name)
    520         initializer = initializers.get(initializer)
    521         with backend.name_scope(self.name, caller=self):
--> 522             variable = backend.Variable(
    523                 initializer=initializer,
    524                 shape=shape,

/usr/local/lib/python3.10/dist-packages/keras/src/backend/common/variables.py in __init__(self, initializer, shape, dtype, trainable, autocast, aggregation, name)
    159         else:
    160             if callable(initializer):
--> 161                 shape = self._validate_shape(shape)
    162                 value = initializer(shape, dtype=dtype)
    163             else:

/usr/local/lib/python3.10/dist-packages/keras/src/backend/common/variables.py in _validate_shape(self, shape)
    182
    183     def _validate_shape(self, shape):
--> 184         shape = standardize_shape(shape)
    185         if None in shape:
    186             raise ValueError(

/usr/local/lib/python3.10/dist-packages/keras/src/backend/common/variables.py in standardize_shape(shape)
    548             continue
    549         if not is_int_dtype(type(e)):
--> 550             raise ValueError(
    551                 f"Cannot convert '{shape}' to a shape. "
    552                 f"Found invalid entry '{e}' of type '{type(e)}'. "

ValueError: Cannot convert '('c', 'o', 'u', 'n', 't', 'e', 'r')' to a shape. Found invalid entry 'c' of type '<class 'str'>'.

Thanks!
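
The last frame of the traceback explains the odd-looking tuple of letters in the message. The sketch below is a simplified stand-in for that validation step, modeled only on the lines visible in the traceback; `standardize_shape_sketch` is a made-up name, and the `isinstance` check replaces the library's own `is_int_dtype` test. It shows that a string is iterable, so it is split into characters before each entry is checked.

```python
def standardize_shape_sketch(shape):
    """Simplified stand-in for the shape check seen in the traceback."""
    # A str is iterable, so "counter" becomes a tuple of single characters.
    shape = tuple(shape)
    for e in shape:
        if e is None:
            continue
        if not isinstance(e, int):
            raise ValueError(
                f"Cannot convert '{shape}' to a shape. "
                f"Found invalid entry '{e}' of type '{type(e)}'. "
            )
    return shape


message = None
try:
    # This mirrors what happens when "counter" is bound to the shape parameter.
    standardize_shape_sketch("counter")
except ValueError as err:
    message = str(err)
    print(message)

# A real shape passes through unchanged.
print(standardize_shape_sketch((128, 32)))
```

So the error is not about the candidates dataset or the model outputs: the string "counter" never reaches the name parameter at all.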
