Default initialisations can produce nan loss #279

Open
andreasgrv opened this issue Oct 8, 2024 · 2 comments
Labels: documentation

Comments

@andreasgrv
Collaborator

Code to reproduce:

import random
import numpy as np
import torch

from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from cirkit.symbolic.circuit import Circuit
from cirkit.templates.region_graph import RandomBinaryTree
from cirkit.symbolic.layers import CategoricalLayer
from cirkit.templates.circuit_templates._factories import name_to_parameter_factory, name_to_initializer
from cirkit.pipeline import compile


NUM_INPUT_UNITS = 64
NUM_SUM_UNITS = 64
PIXEL_RANGE=255

# Load the MNIST data set and data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    # Set pixel values in the [0-255] range
    transforms.Lambda(lambda x: (PIXEL_RANGE * x).long())
])


def define_circuit_from_rg(rg):

    # Here is where Overparametrisation comes in
    input_factory = lambda x, y, z: CategoricalLayer(scope=x,
                                                     num_categories=PIXEL_RANGE+1,
                                                     num_channels=1, # These are grayscale images
                                                     num_output_units=NUM_INPUT_UNITS # Overparametrisation
                                                    )

    ### =========== With init below model trains fine ===================================
    #  sum_weight_init = name_to_initializer('normal')
    #  sum_weight_params = name_to_parameter_factory('softmax', initializer=sum_weight_init)
    ### ========== but if no init - as below, we get nan loss ===========================
    sum_weight_params = None   # This line leads to nan loss
    
    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory= sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')
    return circuit


def train_circuit(cc):

    # Set some seeds
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # torch.cuda.manual_seed(42)

    # Set the torch device to use
    device = torch.device('cuda')

    # Compile the circuit
    circuit = compile(cc)

    # Move the circuit to chosen device
    circuit = circuit.to(device)

    num_epochs = 5
    step_idx = 0
    running_loss = 0.0


    # Initialize a torch optimizer of your choice,
    #  e.g., Adam, by passing the parameters of the circuit
    optimizer = optim.Adam(circuit.parameters(), lr=0.01)

    for epoch_idx in range(num_epochs):
        for i, (batch, _) in enumerate(train_dataloader):
            # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
            # so we unsqueeze a dimension for the channel.
            BS = batch.shape[0]
            batch = batch.view(BS, 1, -1).to(device)

            # Compute the log-likelihoods of the batch, by evaluating the circuit
            log_likelihoods = circuit(batch)

            # We take the negated average log-likelihood as loss
            loss = -torch.mean(log_likelihoods)
            loss.backward()
            # Update the parameters of the circuits, as any other model in PyTorch
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.detach() * len(batch)
            step_idx += 1
            if step_idx % 100 == 0:
                print(f"Step {step_idx}: Average NLL: {running_loss / (100 * len(batch)):.3f}")
                running_loss = 0.0


data_train = datasets.MNIST('datasets', train=True, download=True, transform=transform)
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)

# We can also specify depth and number of repetitions
# depth=None means maximum possible
rnd = RandomBinaryTree(28*28, depth=None, num_repetitions=1)

circuit = define_circuit_from_rg(rnd)

train_circuit(circuit)

In the above code, when the sum weight parameterisation is not specified, the loss becomes NaN during training.
This may be confusing for somebody not familiar with the internals of the library. Is there a way to avoid this?

@andreasgrv
Collaborator Author

Output for:

    sum_weight_params = None   # This line leads to nan loss
    
    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory= sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')

python example.py
Step 100: Average NLL: nan
Step 200: Average NLL: nan
Step 300: Average NLL: nan

On the other hand, if:

    sum_weight_init = name_to_initializer('normal')
    sum_weight_params = name_to_parameter_factory('softmax', initializer=sum_weight_init)

    
    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory= sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')

python example.py
Step 100: Average NLL: 3422.423
Step 200: Average NLL: 1614.733
Step 300: Average NLL: 1013.035

@lkct
Member

lkct commented Oct 9, 2024

This is because the sum weights are initialised from a Normal distribution by default, but they are expected to be positive in "common" circuits, and negative values produce NaN in the log-sum-exp.
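
To illustrate the failure mode, here is a minimal standalone sketch in plain PyTorch (not cirkit's internals): in log-space a sum unit effectively computes logsumexp(x + log(w)), so any negative weight turns log(w) into NaN, which then propagates to the output.

import torch

# Minimal illustration, not cirkit code: a log-space sum unit computes
# log(sum_i w_i * exp(x_i)) as logsumexp(x + log(w)).
x = torch.tensor([-3.2, -1.7, -2.5])   # child log-values
w = torch.tensor([0.8, -0.3, 0.5])     # e.g. a Normal-initialised weight vector with a negative entry
print(torch.logsumexp(x + torch.log(w), dim=0))   # tensor(nan), because log(-0.3) is nan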

However, we also have many projects that use negative weights (with the sum-product or complex-lse-sum semiring), so it makes sense to use a Normal init.

This may be confusing for somebody not familiar with the internals of the library. Is there a way to avoid this?

Considering this, I would agree to changing the default init for sum weights.

But in any case, we should properly document the default initialisation for each layer and tell users when they should NOT rely on the default.
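
For context, here is a rough plain-PyTorch sketch (not cirkit's actual implementation) of what the explicit 'softmax' parameterisation in the workaround above buys: the raw parameter can still be Normal-initialised, but the weights entering the log-sum-exp are always positive and normalised, so the loss stays finite.

import torch
from torch import nn

# Rough sketch, not cirkit code: an unconstrained raw parameter is
# reparameterised through a softmax to obtain strictly positive, normalised weights.
raw = nn.Parameter(torch.randn(4))    # Normal init on the raw parameter is harmless here
weights = torch.softmax(raw, dim=0)   # strictly positive, sums to 1
x = torch.randn(4)                    # child log-values
print(torch.logsumexp(x + torch.log(weights), dim=0))   # finite; `raw` remains the trainable parameter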

lkct added the documentation label on Oct 9, 2024