forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mars.R
92 lines (68 loc) · 3.06 KB
/
mars.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
# Bring in library that contains multivariate adaptive regression splines (MARS)
library(mda)
# Bring in library that allows parsing of JSON training parameters
library(jsonlite)
# Bring in library for prediction server
library(plumber)
# Set up parameters
# Root of the container directory layout; every other path hangs off it
prefix <- '/opt/ml'
# Directory that holds one sub-directory per input data channel
input_path <- file.path(prefix, 'input/data')
# Directory where the training status marker is written
output_path <- file.path(prefix, 'output')
# Directory where the fitted model artifact is saved
model_path <- file.path(prefix, 'model')
# JSON file with the training hyperparameters
param_path <- file.path(prefix, 'input/config/hyperparameters.json')
# Channel holding training data, and the directory it resolves to
channel_name <- 'train'
training_path <- file.path(input_path, channel_name)
# Set up training function
# Fit a MARS model on the CSV files found in `training_path` and persist it.
#
# Hyperparameters are read from the JSON file at `param_path`:
#   target - (required) name of the response column in the training data
#   degree - (optional, default 2) passed to mars() as `degree`
#   thresh - (optional, default 0.001) passed to mars() as `thresh`
#   prune  - (optional, default TRUE) passed to mars() as `prune`
#
# Side effects:
#   * saves `mars_model` and `factor_levels` to `model_path`/mars_model.RData
#   * prints a model summary plus the GCV and training MSE
#   * writes a 'success' marker file into `output_path`
#
# Stops with an error when the target hyperparameter is missing, when no
# training files are present, or when the target column is not in the data.
train <- function() {
  # Read in hyperparameters. `[[` is used instead of `$` so that a key such as
  # 'target_name' is never silently picked up through partial matching.
  training_params <- read_json(param_path)

  # Return the converted hyperparameter, or `default` when it is absent
  param_or <- function(name, default, convert) {
    value <- training_params[[name]]
    if (is.null(value)) {
      default
    } else {
      convert(value)
    }
  }

  target <- training_params[['target']]
  if (is.null(target)) {
    stop("Hyperparameter 'target' is required", call. = FALSE)
  }
  degree <- param_or('degree', 2, as.numeric)
  thresh <- param_or('thresh', 0.001, as.numeric)
  prune <- param_or('prune', TRUE, as.logical)

  # Bring in data
  training_files <- list.files(path = training_path, full.names = TRUE)
  if (length(training_files) == 0) {
    stop('No training files found in ', training_path, call. = FALSE)
  }
  # stringsAsFactors is set explicitly: since R 4.0 read.csv() keeps strings
  # as character, which would leave `factor_levels` below empty.
  training_data <- do.call(
    rbind,
    lapply(training_files, read.csv, stringsAsFactors = TRUE)
  )
  if (!(target %in% colnames(training_data))) {
    stop("Target column '", target, "' not found in training data",
         call. = FALSE)
  }

  # Convert to model matrix. drop = FALSE keeps a data frame even when only a
  # single predictor column remains, which model.matrix(~., ...) requires.
  predictors <- training_data[, colnames(training_data) != target,
                              drop = FALSE]
  training_X <- model.matrix(~., predictors)

  # Save factor levels for scoring
  is_factor <- vapply(training_data, is.factor, logical(1))
  factor_levels <- lapply(training_data[, is_factor, drop = FALSE], levels)

  # Run multivariate adaptive regression splines algorithm
  response <- training_data[, target]
  model <- mars(x = training_X, y = response, degree = degree,
                thresh = thresh, prune = prune)

  # Generate outputs. The per-observation components are dropped from the
  # saved object; the class is restored because list subsetting strips it.
  mars_model <- model[!(names(model) %in% c('x', 'residuals', 'fitted.values'))]
  class(mars_model) <- 'mars'
  save(mars_model, factor_levels,
       file = file.path(model_path, 'mars_model.RData'))
  print(summary(mars_model))
  print(paste('gcv:', mars_model$gcv))
  # Mean (not sum) of squared errors, so the value matches its 'mse' label
  print(paste('mse:', mean((model$fitted.values - response)^2)))
  write('success', file = file.path(output_path, 'success'))
}
# Set up scoring function
# Start the prediction server.
#
# Builds a plumber router from the 'plumber.R' file located under `prefix`
# and runs it on all interfaces (0.0.0.0) at port 8080.
serve <- function() {
  api_definition <- paste(prefix, 'plumber.R', sep = '/')
  router <- plumb(api_definition)
  router$run(host = '0.0.0.0', port = 8080)
}
# Run at start-up: dispatch on the mode given on the command line.
# The mode must appear as a standalone argument ('train' or 'serve').
# Exact matching is used because commandArgs() also returns the interpreter
# and script paths; a substring match would fire on any path that merely
# contains 'train' or 'serve'.
args <- commandArgs()
if ('train' %in% args) {
  train()
}
if ('serve' %in% args) {
  serve()
}