Model inference script #2
Hi @CretuCalin, thank you for your feedback! There isn't a script for single-image inference in the repository, but you can run it as follows:

```python
import torch
from torchvision import transforms
from torchvision.io import read_image

from train import ChessResNeXt

# Load the model from the checkpoint and its hyperparameters file
model = ChessResNeXt.load_from_checkpoint(
    "path/to/checkpoint.ckpt", hparams_file="path/to/hparams.yaml"
)

# Read the input image as a float tensor of shape (C, H, W)
img = read_image("path/to/img").float()

# Preprocessing; this should match the transforms used during training
transform = transforms.Compose([
    transforms.Resize(1024, antialias=None),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.47225544, 0.51124555, 0.55296206],
        std=[0.27787283, 0.27054584, 0.27802786]),
])

# Transform the input image and add a batch dimension: (1, C, H, W)
img = transform(img).unsqueeze(dim=0)

# Run inference without tracking gradients
model.eval()
with torch.no_grad():
    preds = model(img)

# Reshape the logits to (batch, 64 squares, 13 classes) and keep
# the most likely class for each square
preds_readable = torch.argmax(preds.reshape((-1, 64, 13)), dim=2)
```

Then, the variable `preds_readable` holds one predicted class index (0 to 12) for each of the 64 squares. For the initial state of the chessboard (no moves made), the tensor looks as follows:

```
tensor([[ 7, 8, 9, 10, 11, 9, 8, 7, 6, 6, 6, 6, 6, 6, 6, 6, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 2, 3, 4, 5, 3, 2, 1]])
```

I hope this helps.
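If you need the prediction in a standard notation, the 64 indices can be turned into a FEN piece-placement string. The sketch below is not part of the repository and rests on assumptions inferred only from the starting-position tensor above: classes 0 to 5 are taken to be white pawn, rook, knight, bishop, queen, king, classes 6 to 11 the same pieces for black, 12 an empty square, and the squares are assumed to run a8 to h8, then rank by rank down to a1 to h1. `indices_to_fen` is a hypothetical helper name.

```python
# Hypothetical helper. The class-to-piece mapping and square order are
# ASSUMPTIONS inferred from the starting-position example, not taken
# from the repository's code.
PIECES = "PRNBQKprnbqk"  # index -> FEN letter (uppercase = white) for 0..11
EMPTY = 12               # assumed "empty square" class


def indices_to_fen(indices):
    """Convert 64 class indices (a8..h8, a7..h7, ..., a1..h1) to FEN piece placement."""
    if len(indices) != 64:
        raise ValueError(f"expected 64 squares, got {len(indices)}")
    ranks = []
    for rank_start in range(0, 64, 8):
        row = ""
        empty_run = 0  # consecutive empty squares not yet written
        for idx in indices[rank_start:rank_start + 8]:
            if not 0 <= idx <= EMPTY:
                raise ValueError(f"unknown class index {idx}")
            if idx == EMPTY:
                empty_run += 1
                continue
            if empty_run:
                row += str(empty_run)
                empty_run = 0
            row += PIECES[idx]
        if empty_run:
            row += str(empty_run)
        ranks.append(row)
    return "/".join(ranks)
```

With the inference code above, `indices_to_fen(preds_readable[0].tolist())` gives the piece placement for the first (and only) image in the batch; for the starting-position tensor it yields `rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR`. If the repository's label mapping turns out to differ, only `PIECES` and `EMPTY` need to change.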
First of all, I really enjoyed your paper, especially the extensive comparison with Chesscog 👍
How can I perform inference with the checkpoint that you provided? Is there a script for this?