You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hey guys, I was trying to use haiku to create a convolutional neural network architecture (to reproduce a paper, whose implementation was in TF2.0). The CNN works correctly, however, when I use jax.test_util.check_grads, there seems to be an error. The code is as follows:
import jax.numpy as np
import jax
from functools import partial
class CNNParameterization(hk.Module):
def __init__(self):
super().__init__()
self.layers = self._build_layers()
def _build_layers(self):
activation = jax.nn.leaky_relu
Nx = 64
Ny = 64
total_resize = onp.prod((1, 2, 2, 2, 1 ))
h = Nx // total_resize
w = Ny // total_resize
layers = []
self.latent_params = hk.get_parameter(
"beta", shape=(128, ),
init=hk.initializers.RandomNormal())
dense_output_size = 32*w*h
dense_init = hk.initializers.Orthogonal(
scale=1.0*onp.sqrt(onp.max(
(dense_output_size/128, 1)
)))
dense_layer = hk.Linear(output_size=dense_output_size,
name='dense_layer',
w_init=dense_init)
layers.append(dense_layer)
# Reshape preserves batch dimension
layers.append(hk.Reshape((h, w, 32),
name="reshape"))
counter = 0
for resize, conv_filters in zip((1, 2, 2, 2, 1), \
(32, 16, 8, 4, 1)):
layers.append(activation)
layers.append(hk.Conv2D(output_channels=conv_filters,
kernel_shape=(5, 5),
padding='SAME',
name='conv_layer',
w_init=hk.initializers.VarianceScaling()))
counter += 1
return layers
def __call__(self, model_input: jax.Array = None):
"""Forward pass.
The model input is unused.
"""
del model_input
x = self.latent_params
for layer_no, layer in enumerate(self.layers):
if layer_no == 1: # Only for reshaping layer
x = layer(x.reshape((1, ) + layer.output_shape))
else:
x = layer(x)
x = x.ravel()
return x
# Test the gradients
def mapping_fn(x):
result = CNNParameterization()(x)
return result
model_input = np.ones((100, 3))
forward_pass_pure = hk.without_apply_rng(
hk.transform_with_state(mapping_fn))
init_params, init_state = forward_pass_pure.init(x=model_input,
rng=rng_key)
forward_func = jax.jit(forward_pass_pure.apply)
def dummy_func(params, state, x):
return forward_func(params, state, x)[0].mean()
check_grads(dummy_func, (init_params, init_state, model_input),
order=2, eps=1e-4)
The error is :
AssertionError:
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.04407668
Max relative difference: 0.01132658
x: array(-3.847362, dtype=float32)
y: array(-3.891438, dtype=float32)
The text was updated successfully, but these errors were encountered:
Hey guys, I was trying to use haiku to create a convolutional neural network architecture (to reproduce a paper, whose implementation was in TF2.0). The CNN works correctly, however, when I use
jax.test_util.check_grads
, there seems to be an error. The code is as follows:The error is :
The text was updated successfully, but these errors were encountered: