Gematria inference API

This document describes the APIs for inference with trained Gematria models.

Command-line inference API

The module gematria.model.python.main_function provides an inference mode in which the binary reads a .tfrecord file whose records each contain a single serialized BasicBlockWithThroughputProto. The output is written to another file in the same format, preserving the order of the samples.

Model binaries using this module support inference automatically. The required flags to run inference are:

  • --gematria_action=predict: Runs the model in batch inference mode.
  • --gematria_input_file={filename}: The path to the input .tfrecord file.
  • --gematria_output_file={filename}: The path to the output .tfrecord file.
  • --gematria_checkpoint_file={checkpoint}: The path to a TensorFlow checkpoint that contains the trained model used for inference.

In addition to these flags, you must also provide the parameters of the model in model-specific flags with the same values as those used to train the model.

Example command line:

$ bazel run -c opt \
    //gematria/granite/python:run_granite_model \
    -- \
    --gematria_action=predict \
    --gematria_input_file=/tmp/input.tfrecord \
    --gematria_output_file=/tmp/output.tfrecord \
    --gematria_tokens_file=/tmp/tokens.txt \
    --gematria_checkpoint_file=/tmp/granite_model/model.ckpt-10000
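When inference runs from a script rather than an interactive shell, the same command can be assembled programmatically. The following is a minimal sketch that uses only the flags listed above. The helper name build_predict_command and its parameters are hypothetical and not part of Gematria; model-specific flags such as --gematria_tokens_file are passed through unchanged.

```python
import shlex


def build_predict_command(
    input_file,
    output_file,
    checkpoint_file,
    model_flags=(),
    target="//gematria/granite/python:run_granite_model",
):
  """Returns the argv list for a batch inference run (hypothetical helper)."""
  return [
      "bazel", "run", "-c", "opt", target, "--",
      "--gematria_action=predict",
      f"--gematria_input_file={input_file}",
      f"--gematria_output_file={output_file}",
      f"--gematria_checkpoint_file={checkpoint_file}",
      # Model-specific flags must match the values used for training.
      *model_flags,
  ]


command = build_predict_command(
    "/tmp/input.tfrecord",
    "/tmp/output.tfrecord",
    "/tmp/granite_model/model.ckpt-10000",
    model_flags=["--gematria_tokens_file=/tmp/tokens.txt"],
)
# Print the command in a form that can be pasted into a shell.
print(shlex.join(command))
```

Passing the resulting list to subprocess.run(command, check=True) would launch the run; building the command as a list avoids shell quoting problems with unusual file names.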

Python inference API

Python code can interact directly with the Gematria model class, without going through a .tfrecord file. Gematria models based on the gematria.model.python.model_base.ModelBase class all provide a Predict method that takes a list of BasicBlockWithThroughputProto and returns a list of the same protos with the predictions added to them.

Example code using the Python API:

import tensorflow.compat.v1 as tf

from gematria.basic_block.python import tokens
from gematria.granite.python import token_graph_builder_model
from gematria.model.python import options

_INPUT_BLOCKS = []     # Replace with a list of BasicBlockWithThroughputProtos.
_CHECKPOINT_FILE = ''  # Replace with a path to the TensorFlow checkpoint.

_MODEL_TOKENS = []     # Replace with a list of tokens used for training the model.

model = token_graph_builder_model.TokenGraphBuilderModel(
    tokens=_MODEL_TOKENS,
    dtype=tf.dtypes.float32,
    immediate_token=tokens.IMMEDIATE,
    fp_immediate_token=tokens.IMMEDIATE,
    address_token=tokens.ADDRESS,
    memory_token=tokens.MEMORY,
    node_embedding_size=256,
    edge_embedding_size=256,
    global_embedding_size=256,
    node_update_layers=(256, 256),
    edge_update_layers=(256, 256),
    global_update_layers=(256, 256),
    readout_layers=(256, 256),
    task_readout_layers=(256, 256),
    num_message_passing_iterations=8,
    loss_type=options.LossType.MEAN_SQUARED_ERROR,
    loss_normalization=options.ErrorNormalization.PERCENTAGE_ERROR
)
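# Build the TensorFlow graph for the model.
model.Initialize()
with tf.Session() as sess:
  # Restore the trained weights from the checkpoint.
  saver = tf.train.Saver()
  saver.restore(sess, _CHECKPOINT_FILE)
  # Run inference; the returned protos carry the predictions.
  output_blocks = model.Predict(sess, _INPUT_BLOCKS)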