-
Notifications
You must be signed in to change notification settings - Fork 15
/
score.py
76 lines (60 loc) · 1.98 KB
/
score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# This script generates the scoring and schema files
# necessary to operationalize your model
from azureml.api.schema.dataTypes import DataTypes
from azureml.api.schema.sampleDefinition import SampleDefinition
from azureml.api.realtime.services import generate_schema
from azureml.assets import get_local_path
# Import frameworks
import pandas as pd
import xgboost
import arcgis
import numpy as np
import pickle
import json
# Web service definition: author init() and run() below and test them
# locally before deploying the web service.
#
# Module-level state shared between the entry points: init() assigns both
# names, and run() reads ``model``. Both stay None until init() is called.
model = wrangler = None
def init(wrangler_path='wrangler.pkl', model_path='0001.xgbmodel'):
    """
    Load the artifacts that run() depends on into module-level globals.

    * ``wrangler`` -- object unpickled from ``wrangler_path``
      (presumably the data-transformation step; TODO confirm).
    * ``model`` -- ``xgboost.Booster`` loaded from ``model_path``.

    NOTE(review): credentials and road static features are not loaded
    here; add them if the scoring path needs them.

    :param wrangler_path: path of the pickled wrangler file
        (defaults to ``'wrangler.pkl'``, the previously hard-coded name).
    :param model_path: path of the saved XGBoost model file
        (defaults to ``'0001.xgbmodel'``, the previously hard-coded name).
    :return: None
    """
    global model, wrangler
    # SECURITY: pickle.load can execute arbitrary code while unpickling.
    # Only point this at a wrangler file produced by this project's own
    # pipeline, never at user-supplied data.
    with open(wrangler_path, 'rb') as fp:
        wrangler = pickle.load(fp)
    model = xgboost.Booster(model_file=model_path)
def run(input_df):
    """
    Score one request and return the result as a JSON-encoded string.

    Placeholder implementation: the real prediction call is still
    commented out, so this echoes the input together with the loaded model.

    :param input_df: request payload. generate_api_schema() declares it as a
        pandas DataFrame; the local ``__main__`` test passes a string.
    :return: ``json.dumps`` of the prediction string.
    """
    # TODO: replace the echo with the real prediction, e.g.
    #   prediction = model.predict(input_df)
    # ``model`` is an xgboost.Booster, not a number, so it must be formatted
    # with %s -- formatting it with %d raised TypeError on every call.
    prediction = "%s %s" % (input_df, model)
    return json.dumps(prediction)
def generate_api_schema():
    """
    Build the web-service input schema from ``sample.csv``.

    The CSV is read as an example pandas input and declared as the
    ``input_df`` argument of run(). ``outputs/`` is created if missing,
    ``outputs/schema.json`` is passed to generate_schema as the target
    path, and whatever generate_schema returns is printed.
    """
    import os

    print("create schema")
    sample_frame = pd.read_csv("sample.csv")
    sample_def = SampleDefinition(DataTypes.PANDAS, sample_frame)
    os.makedirs('outputs', exist_ok=True)
    schema = generate_schema(
        inputs={"input_df": sample_def},
        filepath="outputs/schema.json",
        run_func=run,
    )
    print(schema)
# Implement test code to run in IDE or Azure ML Workbench
if __name__ == '__main__':
    # Import the logger only for Workbench runs
    from azureml.logging import get_azureml_logger
    logger = get_azureml_logger()

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--generate', action='store_true', help='Generate Schema')
    args = parser.parse_args()

    # Optionally regenerate outputs/schema.json before the smoke test.
    if args.generate:
        generate_api_schema()

    # Smoke-test the service entry points locally.
    init()
    # Named sample_input (not `input`) so the builtin input() is not shadowed.
    sample_input = "{}"
    result = run(sample_input)
    logger.log("Result", result)