Skip to content
This repository has been archived by the owner on Nov 29, 2024. It is now read-only.

Commit

Permalink
remove requestPredictionIntervals flag from scoring request;ensure ou…
Browse files Browse the repository at this point in the history
…tput schema consistent with score response
  • Loading branch information
jackjii79 committed Sep 7, 2023
1 parent d7e79c9 commit 7347396
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 140 deletions.
21 changes: 12 additions & 9 deletions .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,22 @@ jobs:

# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
- if: matrix.language == 'python'
name: Autobuild
uses: github/codeql-action/autobuild@v2

# ℹ️ Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
# ℹ️ Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl

# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language

#- run: |
# make bootstrap
# make release
- if: matrix.language == 'java'
name: Build Java
uses: gradle/gradle-build-action@v2
with:
arguments: assemble

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
9 changes: 2 additions & 7 deletions common/swagger/v1/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,6 @@ definitions:
type: array
items:
$ref: '#/definitions/Row'
requestPredictionIntervals:
type: boolean
description: |
If set to `true`, the scorer will try to fill field `predictionIntervals` in response if it is supported.
ScoreResponse:
type: object
properties:
Expand Down Expand Up @@ -326,9 +322,8 @@ definitions:
type: object
description: >
A prediction interval consists of an array of interval bound names
and rows of array of bounds per bound name. Setting `requestPredictionIntervals`
to true will enable populating the field. The field will be empty or an error
response returned if prediction intervals are not returned or supported by the model.
and rows of array of bounds per bound name. The field will be empty
if not supported by the model.
$ref: '#/definitions/PredictionInterval'
DataField:
type: object
Expand Down
9 changes: 2 additions & 7 deletions common/swagger/v1openapi3/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,6 @@ components:
An array of rows containing the actual input data for scoring, one scoring request per row.
items:
$ref: '#/components/schemas/Row'
requestPredictionIntervals:
type: boolean
description: |
If set to `true`, the scorer will try to fill field `predictionIntervals` in response if it is supported.
ScoreResponse:
type: object
properties:
Expand Down Expand Up @@ -390,9 +386,8 @@ components:
type: object
description: >-
A prediction interval consists of an array of interval bound names
and rows of array of bounds per bound name. Setting `requestPredictionIntervals`
to true will enable populating the field. The field will be empty or an error
response returned if prediction intervals are not returned or supported by the model.
and rows of array of bounds per bound name. The field will be empty
if not supported by the model.
properties:
fields:
$ref: '#/components/schemas/Row'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ai.h2o.mojos.deploy.common.rest.model.ScoreRequest;
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -19,28 +20,30 @@
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Converts the resulting predicted {@link MojoFrame} into the API response object {@link
* ScoreResponse}.
*/
public class MojoFrameToScoreResponseConverter
implements BiFunction<MojoFrame, ScoreRequest, ScoreResponse> {
private static final Logger log =
LoggerFactory.getLogger(MojoFrameToScoreResponseConverter.class);

// True if the pipeline supports prediction intervals, false otherwise.
// Note: the assumption is that the pipeline supports prediction intervals.
// However, for some h2o3 models, even a classification model may still set
// this to true.
private Boolean supportPredictionInterval;
private List<String> outputFieldNames;

public MojoFrameToScoreResponseConverter(boolean supportPredictionInterval) {
/**
* Converts the resulting predicted {@link MojoFrame} into the API response object {@link
* ScoreResponse}.
*/
public MojoFrameToScoreResponseConverter(
boolean supportPredictionInterval, List<String> outputFieldNames) {
this.supportPredictionInterval = supportPredictionInterval;
this.outputFieldNames = outputFieldNames;
}

public MojoFrameToScoreResponseConverter() {
Expand All @@ -56,6 +59,14 @@ public MojoFrameToScoreResponseConverter() {
@Override
public ScoreResponse apply(
MojoFrame mojoFrame, ScoreRequest scoreRequest) {
Preconditions.checkArgument(
new HashSet<>(Arrays.asList(mojoFrame.getColumnNames()))
.containsAll(getTargetFields(mojoFrame)),
String.format(
"MOJO response frame columns [%s] does not contain all requested output fields [%s]",
String.join(",", mojoFrame.getColumnNames()), String.join(",", getTargetFields(mojoFrame))
)
);
Set<String> includedFields = getSetOfIncludedFields(scoreRequest);
List<Row> outputRows =
Stream.generate(Row::new).limit(mojoFrame.getNrows()).collect(Collectors.toList());
Expand All @@ -66,11 +77,11 @@ public ScoreResponse apply(
response.setScore(outputRows);

if (!Boolean.TRUE.equals(scoreRequest.isNoFieldNamesInOutput())) {
List<String> outputFieldNames = getFilteredInputFieldNames(scoreRequest, includedFields);
outputFieldNames.addAll(getTargetField(mojoFrame));
response.setFields(outputFieldNames);
List<String> outputNames = getFilteredInputFieldNames(scoreRequest, includedFields);
outputNames.addAll(getTargetFields(mojoFrame));
response.setFields(outputNames);
}
fillWithPredictionInterval(mojoFrame, scoreRequest, response);
fillWithPredictionInterval(mojoFrame, response);
return response;
}

Expand All @@ -90,28 +101,20 @@ private void fillOutputRows(
}

/**
* Populate Prediction Interval value into response field.
* Only when score request set requestPredictionIntervals be true
* and MOJO pipeline support prediction interval.
* Populate Prediction Interval value into response field if possible.
*/
private void fillWithPredictionInterval(
MojoFrame mojoFrame, ScoreRequest scoreRequest, ScoreResponse scoreResponse) {
if (Boolean.TRUE.equals(scoreRequest.isRequestPredictionIntervals())) {
if (!supportPredictionInterval) {
throw new IllegalStateException(
"Unexpected error, prediction interval should be supported, but actually not");
MojoFrame mojoFrame, ScoreResponse scoreResponse) {
if (supportPredictionInterval && mojoFrame.getNcols() > 1) {
int targetIdx = getTargetColIdx(getTargetFields(mojoFrame));
// Need to ensure target column is singular (regression).
if (targetIdx >= 0) {
PredictionInterval predictionInterval =
new PredictionInterval().fields(new Row()).rows(Collections.emptyList());
predictionInterval.setFields(getPredictionIntervalFields(mojoFrame, targetIdx));
predictionInterval.setRows(getPredictionIntervalRows(mojoFrame, targetIdx));
scoreResponse.setPredictionIntervals(predictionInterval);
}
PredictionInterval predictionInterval =
new PredictionInterval().fields(new Row()).rows(Collections.emptyList());
if (mojoFrame.getNcols() > 1) {
int targetIdx = getTargetColIdx(Arrays.asList(mojoFrame.getColumnNames()));
// Need to ensure target column is singular (regression).
if (targetIdx >= 0) {
predictionInterval.setFields(getPredictionIntervalFields(mojoFrame, targetIdx));
predictionInterval.setRows(getPredictionIntervalRows(mojoFrame, targetIdx));
}
}
scoreResponse.setPredictionIntervals(predictionInterval);
}
}

Expand Down Expand Up @@ -141,25 +144,12 @@ private List<Row> getTargetRows(MojoFrame mojoFrame) {
* column from MOJO frame, otherwise all column names
* will be extracted.
*/
private List<String> getTargetField(
private List<String> getTargetFields(
MojoFrame mojoFrame) {
if (mojoFrame.getNcols() > 0) {
List<String> targetColumns = Arrays.asList(mojoFrame.getColumnNames());
if (supportPredictionInterval) {
int targetIdx = getTargetColIdx(targetColumns);
if (targetIdx < 0) {
log.debug(
"singular target column does not exist in MOJO response frame,"
+ " this could be a classification model."
);
} else {
return targetColumns.subList(targetIdx, targetIdx + 1);
}
}
return targetColumns;
} else {
return Collections.emptyList();
if (outputFieldNames != null && !outputFieldNames.isEmpty()) {
return outputFieldNames;
}
return Arrays.asList(mojoFrame.getColumnNames());
}

/**
Expand All @@ -169,23 +159,12 @@ private List<String> getTargetField(
* column indices will be extracted.
*/
private List<Integer> getTargetFieldIndices(MojoFrame mojoFrame) {
if (mojoFrame.getNcols() > 0) {
List<String> targetColumns = Arrays.asList(mojoFrame.getColumnNames());
if (supportPredictionInterval) {
int targetIdx = getTargetColIdx(targetColumns);
if (targetIdx < 0) {
log.debug(
"singular target column does not exist in MOJO response frame,"
+ " this could be a classification model."
);
} else {
return Collections.singletonList(targetIdx);
}
}
return IntStream.range(0, mojoFrame.getNcols()).boxed().collect(Collectors.toList());
} else {
return Collections.emptyList();
}
List<String> targetColumns = getTargetFields(mojoFrame);
List<String> frameColumns = Arrays.asList(mojoFrame.getColumnNames());
return targetColumns
.stream()
.map(frameColumns::indexOf)
.collect(Collectors.toList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -50,6 +52,7 @@ public class MojoScorer {
supportPredictionInterval
? loadMojoPipelineFromFile(buildPipelineConfigWithPredictionInterval())
: loadMojoPipelineFromFile();
public static final List<String> modelOutputFieldNames = getModelOutputFieldNames(pipeline);
private final ShapleyLoadOption enabledShapleyTypes;
private final boolean shapleyEnabled;
private static MojoPipeline pipelineTransformedShapley;
Expand Down Expand Up @@ -100,13 +103,6 @@ public MojoScorer(
* @return response {@link ScoreResponse}
*/
public ScoreResponse score(ScoreRequest request) {
if (Boolean.TRUE.equals(request.isRequestPredictionIntervals())
&& !supportPredictionInterval) {
throw new IllegalArgumentException(
"requestPredictionIntervals set to true, but model does not support it"
);
}

scoreRequestTransformer.accept(request, getModelInfo().getSchema().getInputFields());
MojoFrame requestFrame = scoreRequestConverter
.apply(request, pipeline.getInputFrameBuilder());
Expand Down Expand Up @@ -418,6 +414,17 @@ private static File getMojoFile() {
return mojoFile;
}

/**
 * Lists the output column names declared by the given pipeline.
 *
 * @param pipeline pipeline whose output metadata is read
 * @return the column names found in the pipeline's output metadata
 */
private static List<String> getModelOutputFieldNames(MojoPipeline pipeline) {
  final MojoFrameMeta outputMeta = pipeline.getOutputMeta();
  return getOutputFields(outputMeta);
}

/**
 * Collects the column names described by the given frame metadata.
 *
 * @param outputMeta metadata whose columns are visited from index 0 up to {@code size() - 1}
 * @return the column names, in index order
 */
private static List<String> getOutputFields(MojoFrameMeta outputMeta) {
  final int columnCount = outputMeta.size();
  final List<String> columnNames = new ArrayList<>();
  for (int columnIdx = 0; columnIdx < columnCount; columnIdx++) {
    columnNames.add(outputMeta.getColumnName(columnIdx));
  }
  return columnNames;
}

private static boolean checkIfPredictionIntervalSupport() {
File mojoFile = getMojoFile();
try {
Expand Down
Loading

0 comments on commit 7347396

Please sign in to comment.