GDCN implementation #716

Open
an-tran528 opened this issue May 5, 2024 · 2 comments

Comments

@an-tran528

I've been searching for an implementation of GDCN, an updated version of DCN, but it seems it isn't supported yet.

I'm trying to adapt the Cross layer implementation by adding gate layers with a sigmoid activation:

      # Inside Cross.build(): low-rank gate, projection U followed by V with a sigmoid
      self._gate_u = tf.keras.layers.Dense(
          self._projection_dim,
          kernel_initializer=_clone_initializer(self._kernel_initializer),
          kernel_regularizer=self._kernel_regularizer,
          use_bias=False,
          dtype=self.dtype,
      )
      self._gate_v = tf.keras.layers.Dense(
          last_dim,
          kernel_initializer=_clone_initializer(self._kernel_initializer),
          bias_initializer=self._bias_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          use_bias=self._use_bias,
          dtype=self.dtype,
          activation="sigmoid",
      )
      ...

  def call(self, x0, x=None):
    # x0: input to the first cross layer; x: output of the previous layer
    ...  # prod_output is computed as in the original Cross layer
    return x0 * prod_output + self._gate_v(self._gate_u(x)) + x

However, the loss doesn't converge in my use case. Is this implementation correct?

@zhangfan555

According to the paper, shouldn't it be "x0 * prod_output * self._gate_v(self._gate_u(x)) + x"?
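For reference, here is the gated cross layer as I understand it from the GDCN paper, written with the post's variable names (the notation is my own, not necessarily the paper's). The sigmoid gate multiplies the cross term elementwise instead of being added to it:

```math
x_{l+1} = x_0 \odot \left( W_l^{(c)} x_l + b_l \right) \odot \sigma\left( W_l^{(g)} x_l \right) + x_l
```

Here $W_l^{(c)} x_l + b_l$ plays the role of `prod_output`, $\sigma(W_l^{(g)} x_l)$ plays the role of `self._gate_v(self._gate_u(x))` (the post factors the gate weight through two dense layers, which is a low-rank choice), and $\odot$ is elementwise multiplication.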

@rlcauvin

I'd be interested in seeing the full GDCN code once you get it working. Hopefully, @zhangfan555's correction will make the loss converge.
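As a starting point, here is a minimal NumPy sketch of a stacked gated cross network forward pass, assuming the multiplicative form suggested by @zhangfan555. It is not TFRS code: `gated_cross_layer` and `gated_cross_network` are hypothetical helpers, the weights are full-rank (no U/V factorization), and there is no training loop.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gated_cross_layer(x0, x, w_c, b_c, w_g):
    """One gated cross layer: x0 * (x @ w_c + b_c) * sigmoid(x @ w_g) + x.

    x0, x: (batch, d); w_c, w_g: (d, d); b_c: (d,).
    Uses the row-vector convention, so x @ w stands in for W x.
    """
    cross = x @ w_c + b_c          # linear cross term (prod_output in the Cross layer)
    gate = sigmoid(x @ w_g)        # elementwise gate with values in (0, 1)
    return x0 * cross * gate + x   # gate multiplies the cross term; residual is added


def gated_cross_network(x0, params):
    """Stack gated cross layers; params is a list of (w_c, b_c, w_g) tuples."""
    x = x0
    for w_c, b_c, w_g in params:
        x = gated_cross_layer(x0, x, w_c, b_c, w_g)
    return x


rng = np.random.default_rng(0)
batch, d, num_layers = 4, 8, 3
x0 = rng.normal(size=(batch, d))
params = [
    (rng.normal(scale=0.1, size=(d, d)), np.zeros(d), rng.normal(scale=0.1, size=(d, d)))
    for _ in range(num_layers)
]
out = gated_cross_network(x0, params)
print(out.shape)  # (4, 8)
```

One difference from the additive version in the original post: here an all-zero input produces an all-zero output, because the gate only scales the cross term. In the additive form, every layer adds the gate output (a value in (0, 1)) to its result regardless of x0.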
