# Disable tensorflow debugging output
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys
from helpers import preprocess, plot
import numpy as np
import requests
from mlserver.types import InferenceRequest
from mlserver.codecs import NumpyCodec
import tensorflow_datasets as tfds
import tensorflow as tf
# Inference variables
if len(sys.argv) < 2:
    sys.exit("Please provide inference mode (--local or --remote)")
if sys.argv[1] == "--local":
    inference_url = 'http://localhost:8080/v2/models/cassava/infer'
elif sys.argv[1] == "--remote":
    inference_url = 'http://localhost:8080/seldon/default/cassava/v2/models/infer'
else:
    sys.exit("Please provide inference mode (--local or --remote)")
batch_size = 16
# Load the dataset and class names
print("Lodaing dataset...")
dataset, info = tfds.load('cassava', with_info=True)
class_names = info.features['label'].names + ['unknown']
# Shuffle the dataset with a buffer size equal to the number of examples in the 'validation' split
validation_dataset = dataset['validation']
buffer_size = info.splits['validation'].num_examples
shuffled_validation_dataset = validation_dataset.shuffle(buffer_size)
# Select a batch of examples from the validation dataset
batch = shuffled_validation_dataset.map(preprocess).batch(batch_size).as_numpy_iterator()
examples = next(batch)
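# `preprocess` and `plot` come from the accompanying helpers module; preprocess is
# assumed to resize/normalize each image so the batch matches the model's input shape.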
# Convert the TensorFlow tensor to a numpy array
input_data = np.array(examples['image'])
# Build the inference request
inference_request = InferenceRequest(
    inputs=[
        NumpyCodec.encode_input(name="payload", payload=input_data)
    ]
)
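# NumpyCodec.encode_input packs the numpy batch into a single V2-inference-protocol
# input tensor named "payload". To inspect the request before sending it, uncomment:
# print(inference_request.dict()["inputs"][0]["shape"])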
# Send the inference request and capture response
print("Sending Inference Request...")
res = requests.post(inference_url, json=inference_request.dict())
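# Optional hardening (not in the original flow): uncomment to fail fast on HTTP errors
# before trying to parse the response body.
# res.raise_for_status()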
print("Got Response...")
# Parse the JSON string into a Python dictionary
response_dict = res.json()
# Extract the data array and shape from the output, assuming only one output or the target output is at index 0
data_list = response_dict["outputs"][0]["data"]
data_shape = response_dict["outputs"][0]["shape"]
# Convert the data list to a numpy array and reshape it
data_array = np.array(data_list).reshape(data_shape)
print("Predictions:", data_array)
# Convert the numpy array to tf tensor
data_tensor = tf.convert_to_tensor(np.squeeze(data_array), dtype=tf.int64)
# Plot the examples with their predictions
plot(examples, class_names, data_tensor)