Model inference script #2

Closed
CretuCalin opened this issue Jan 19, 2024 · 1 comment

@CretuCalin

First of all, I really enjoyed your paper, especially the extensive comparison with Chesscog 👍

How can I perform inference with the checkpoint you provided? Is there a script for this?

@ThanosM97
Owner

Hi @CretuCalin ,

Thank you for your feedback!

There isn't a script for single-instance inference in the repository, but you can do it as follows:

```python
import torch
from torchvision import transforms
from torchvision.io import read_image

from train import ChessResNeXt  # model class defined in this repository's train.py

# Load the checkpoint
model = ChessResNeXt.load_from_checkpoint("path/to/checkpoint.ckpt", hparams_file="path/to/hparams.yaml")

# Read the input image as a float tensor of shape (C, H, W)
img = read_image("path/to/img").float()

# Initialize the transform (keep it identical to the preprocessing used during training)
transform = transforms.Compose([
    transforms.Resize(1024, antialias=None),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.47225544, 0.51124555, 0.55296206],
        std=[0.27787283, 0.27054584, 0.27802786]),
])

# Transform the input image
img = transform(img)

# Add a batch dimension: (C, H, W) -> (1, C, H, W)
img = img.unsqueeze(dim=0)

# Perform inference in evaluation mode, without tracking gradients
model.eval()
with torch.no_grad():
    preds = model(img)

# Reshape to (batch, 64 squares, 13 classes) and keep the most likely class per square
preds_readable = torch.argmax(preds.reshape((-1, 64, 13)), dim=2)
```

The variable preds_readable then holds a tensor of shape (1, 64), where each of the 64 entries is the predicted class_id for the corresponding square (a piece type or empty).
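To make the reshape-and-argmax step concrete, here is a minimal pure-Python sketch of what `torch.argmax(preds.reshape((-1, 64, 13)), dim=2)` computes for a single image. It assumes the model output is a flat vector of 64 × 13 scores stored square by square, which is what that row-major reshape implies; the helper name `logits_to_class_ids` is hypothetical and not part of the repository.

```python
NUM_SQUARES = 64
NUM_CLASSES = 13  # 12 piece classes plus "empty" (id 12 in the starting-position example)


def logits_to_class_ids(logits: list[float]) -> list[int]:
    """Pure-Python equivalent of torch.argmax(preds.reshape((-1, 64, 13)), dim=2) for one image."""
    if len(logits) != NUM_SQUARES * NUM_CLASSES:
        raise ValueError(f"expected {NUM_SQUARES * NUM_CLASSES} scores, got {len(logits)}")
    class_ids = []
    for square in range(NUM_SQUARES):
        # Row-major reshape: square i owns the 13 consecutive scores starting at i * 13.
        scores = logits[square * NUM_CLASSES:(square + 1) * NUM_CLASSES]
        # Index of the highest score; max() keeps the first index on ties.
        class_ids.append(max(range(NUM_CLASSES), key=scores.__getitem__))
    return class_ids


# Toy input: square i scores highest on class i % 13.
toy = [0.0] * (NUM_SQUARES * NUM_CLASSES)
for square in range(NUM_SQUARES):
    toy[square * NUM_CLASSES + square % NUM_CLASSES] = 1.0
print(logits_to_class_ids(toy)[:14])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0]
```

With the real model output you would keep the PyTorch one-liner; the sketch only spells out the indexing it performs.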

For the initial state of the chessboard (no moves made), the tensor would look as follows:

```
tensor([[ 7,  8,  9, 10, 11,  9,  8,  7,  6,  6,  6,  6,  6,  6,  6,  6, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,  0,  0,  0,  0,  0,  0,
          0,  0,  1,  2,  3,  4,  5,  3,  2,  1]])
```
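If you need a human-readable board, the class ids can be converted into the piece-placement field of a FEN string. This sketch infers its label map from the starting-position tensor above: ids 0-5 and 6-11 are the same six piece types (pawn, rook, knight, bishop, queen, king) for the two sides, and 12 is an empty square. Which side is White (here 0-5) and the square order (here index 0 = a8, rank by rank down to h1) are assumptions, so check them against the repository's label definitions; `class_ids_to_fen` is a hypothetical helper.

```python
# Hypothetical label map inferred from the starting-position example:
# ids 0-5 are White's P, R, N, B, Q, K; ids 6-11 are Black's; 12 is empty.
# Both the colour assignment and the square order (index 0 = a8) are assumptions.
ID_TO_FEN = "PRNBQKprnbqk"
EMPTY_ID = 12


def class_ids_to_fen(class_ids: list[int]) -> str:
    """Convert 64 per-square class ids into a FEN piece-placement string."""
    if len(class_ids) != 64:
        raise ValueError(f"expected 64 class ids, got {len(class_ids)}")
    ranks = []
    for start in range(0, 64, 8):  # one rank of 8 squares at a time, a-file first
        row = ""
        empty_run = 0
        for class_id in class_ids[start:start + 8]:
            if class_id == EMPTY_ID:
                empty_run += 1  # consecutive empty squares collapse into one digit
                continue
            if not 0 <= class_id < EMPTY_ID:
                raise ValueError(f"unknown class id {class_id}")
            if empty_run:
                row += str(empty_run)
                empty_run = 0
            row += ID_TO_FEN[class_id]
        if empty_run:
            row += str(empty_run)
        ranks.append(row)
    return "/".join(ranks)  # FEN lists ranks from 8 down to 1


start_position = [7, 8, 9, 10, 11, 9, 8, 7] + [6] * 8 + [12] * 32 + [0] * 8 + [1, 2, 3, 4, 5, 3, 2, 1]
print(class_ids_to_fen(start_position))
# rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR
```

With the inference snippet's output, the call would be `class_ids_to_fen(preds_readable[0].tolist())`.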

I hope this will help you.

@ThanosM97 pinned this issue Jan 24, 2024