-
Notifications
You must be signed in to change notification settings - Fork 381
/
AndroidRealtimePrediction.java
87 lines (73 loc) · 3.29 KB
/
AndroidRealtimePrediction.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/*
* Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express
* or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
import java.util.HashMap;
import java.util.Map;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.services.machinelearning.AmazonMachineLearningClient;
import com.amazonaws.services.machinelearning.model.EntityStatus;
import com.amazonaws.services.machinelearning.model.GetMLModelRequest;
import com.amazonaws.services.machinelearning.model.GetMLModelResult;
import com.amazonaws.services.machinelearning.model.PredictRequest;
import com.amazonaws.services.machinelearning.model.PredictResult;
import com.amazonaws.services.machinelearning.model.RealtimeEndpointStatus;
/**
 * Android code to make realtime predictions from Android
 * using Amazon Machine Learning.
 *
 * Instantiate this class with an mlModelId, and then call
 * predict() method with your record.
 *
 * The model's real-time endpoint URL is looked up once, in the constructor,
 * and reused for every subsequent predict() call.
 */
public class AndroidRealtimePrediction {
    // Model id
    private final String mlModelId;
    // Real-time endpoint URL for the model, resolved once at construction time
    private final String endpoint;
    // AML client used for both the GetMLModel lookup and the Predict calls
    private final AmazonMachineLearningClient client;

    /**
     * Creates a predictor for the given model and resolves its real-time endpoint.
     *
     * @param mlModelId   id of the Amazon ML model to predict with
     * @param credentials AWS credentials used to construct the AML client
     * @throws IllegalStateException if the model's status is not COMPLETED or its
     *         real-time endpoint status is not READY
     */
    public AndroidRealtimePrediction(String mlModelId, AWSCredentials credentials) {
        this.mlModelId = mlModelId;
        this.client = new AmazonMachineLearningClient(credentials);
        // mlModelId and client are assigned above, so the lookup can use them.
        this.endpoint = getRealtimeEndpoint();
    }

    /**
     * Calls GetMLModel and returns the model's real-time endpoint URL.
     * Checks that the model is completed and that its real-time endpoint is
     * ready for predict calls.
     *
     * @return the endpoint URL reported in the model's endpoint info
     * @throws IllegalStateException if the model is not COMPLETED or the
     *         endpoint is not READY (including when no endpoint info is returned)
     */
    private String getRealtimeEndpoint() {
        GetMLModelRequest request = new GetMLModelRequest();
        request.setMLModelId(mlModelId);
        GetMLModelResult result = client.getMLModel(request);

        // Constant-first comparison: a missing status is reported as "not completed"
        // rather than escaping as a NullPointerException.
        if (!EntityStatus.COMPLETED.toString().equals(result.getStatus())) {
            throw new IllegalStateException("ML model " + mlModelId + " needs to be completed.");
        }

        // Defensive: if the response carries no endpoint info, treat the endpoint
        // as not ready instead of dereferencing null.
        String endpointStatus = result.getEndpointInfo() == null
                ? null
                : result.getEndpointInfo().getEndpointStatus();
        if (!RealtimeEndpointStatus.READY.toString().equals(endpointStatus)) {
            throw new IllegalStateException("ML model " + mlModelId + "'s real-time endpoint is not yet ready or needs to be created.");
        }
        // endpointStatus is non-null here, so getEndpointInfo() is non-null too.
        return result.getEndpointInfo().getEndpointUrl();
    }

    /**
     * Once the real-time endpoint is acquired, we can start calling predict for our model.
     * Sends one record to the model's real-time endpoint and returns the service response.
     *
     * @param record attribute=value pairs for the model's input; render numbers as strings
     * @return the PredictResult returned by the AML Predict call
     */
    public PredictResult predict(Map<String, String> record) {
        PredictRequest request = new PredictRequest();
        request.setMLModelId(mlModelId);
        request.setPredictEndpoint(endpoint);
        // Populate record with data relevant to the ML model
        request.setRecord(record);
        return client.predict(request);
    }
}