Skip to content

Commit

Permalink
Dynamically set capabilities endpoint based on configuration (#281)
Browse files Browse the repository at this point in the history
* set supported capabilities dynamically

* review: small improvements

* adds tests

* review: improvements to tests, better mocking, improve maintainability
  • Loading branch information
nkpng2k authored Jan 11, 2022
1 parent e68f076 commit 8ee42a4
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ public Model getModelInfo() {
return modelInfoConverter.apply(pipeline);
}

public ShapleyLoadOption getEnabledShapleyTypes() {
return enabledShapleyTypes;
}

/**
* Method to load mojo pipelines for shapley scoring based on configuration
*
Expand Down
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ awsLambdaEventsVersion = 2.2.3
awsSdkS3Version = 1.11.445
javaxAnnotationVersion = 1.3.2
gsonVersion = 2.8.5
jupiterVersion = 5.3.1
jupiterVersion = 5.4.0
jupiterSystemStubsVersion = 1.2.0
mockitoVersion = 3.0.0
mockitoVersion = 3.4.0
springFoxVersion = 3.0.0
swaggerCodegenVersion = 3.0.0
swaggerCoreVersion = 2.0.5
Expand Down
2 changes: 2 additions & 0 deletions local-rest-scorer/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ dependencies {

testImplementation group: 'org.springframework.boot', name: 'spring-boot-starter-test'
testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension'
testImplementation group: 'org.mockito', name: 'mockito-inline'
testImplementation group: 'org.mockito', name : 'mockito-core'
testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api'
testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params'
testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse;
import ai.h2o.mojos.deploy.common.transform.MojoScorer;
import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder;
import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption;
import com.google.common.base.Strings;
import java.io.IOException;
import java.util.Arrays;
Expand All @@ -24,13 +25,13 @@
@Controller
public class ModelsApiController implements ModelApi {

private static final List<CapabilityType> SUPPORTED_CAPABILITIES
= Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED);
private static final Logger log = LoggerFactory.getLogger(ModelsApiController.class);

private final MojoScorer scorer;
private final SampleRequestBuilder sampleRequestBuilder;

private final List<CapabilityType> supportedCapabilities;

/**
* Simple Api controller. Inherits from {@link ModelApi}, which controls global, expected request
* mappings for the rest service.
Expand All @@ -44,6 +45,9 @@ public class ModelsApiController implements ModelApi {
public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequestBuilder) {
this.scorer = scorer;
this.sampleRequestBuilder = sampleRequestBuilder;
this.supportedCapabilities = assembleSupportedCapabilities(
scorer.getEnabledShapleyTypes()
);
}

@Override
Expand All @@ -58,7 +62,7 @@ public ResponseEntity<String> getModelId() {

@Override
public ResponseEntity<List<CapabilityType>> getCapabilities() {
return ResponseEntity.ok(SUPPORTED_CAPABILITIES);
return ResponseEntity.ok(supportedCapabilities);
}

@Override
Expand Down Expand Up @@ -116,4 +120,22 @@ public ResponseEntity<ContributionResponse> getContribution(
public ResponseEntity<ScoreRequest> getSampleRequest() {
return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta()));
}

private static List<CapabilityType> assembleSupportedCapabilities(
ShapleyLoadOption enabledShapleyTypes) {
switch (enabledShapleyTypes) {
case ALL:
return Arrays.asList(
CapabilityType.SCORE,
CapabilityType.CONTRIBUTION_ORIGINAL,
CapabilityType.CONTRIBUTION_TRANSFORMED);
case ORIGINAL:
return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL);
case TRANSFORMED:
return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED);
case NONE:
default:
return Arrays.asList(CapabilityType.SCORE);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package ai.h2o.mojos.deploy.local.rest.controller;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import ai.h2o.mojos.deploy.common.rest.model.CapabilityType;
import ai.h2o.mojos.deploy.common.transform.MojoScorer;
import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder;
import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.api.MojoPipelineService;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.ResponseEntity;

@ExtendWith(MockitoExtension.class)
class ModelsApiControllerTest {
@Mock private SampleRequestBuilder sampleRequestBuilder;

@BeforeAll
static void setup() throws IOException {
File tmpModel = File.createTempFile("pipeline", ".mojo");
System.setProperty("mojo.path", tmpModel.getAbsolutePath());
mockMojoPipeline(tmpModel);
}

private static void mockMojoPipeline(File tmpModel) {
MojoPipeline mojoPipeline = Mockito.mock(MojoPipeline.class);
MockedStatic<MojoPipelineService> theMock = Mockito.mockStatic(MojoPipelineService.class);
theMock.when(() -> MojoPipelineService
.loadPipeline(new File(tmpModel.getAbsolutePath()))).thenReturn(mojoPipeline);
}

@Test
void verifyCapabilities_DefaultShapley_ReturnsExpected() {
// Given
List<CapabilityType> expectedCapabilities = Arrays.asList(CapabilityType.SCORE);

MojoScorer scorer = mock(MojoScorer.class);
when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.NONE);

ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder);

// When
ResponseEntity<List<CapabilityType>> response = controller.getCapabilities();

// Then
assertEquals(expectedCapabilities, response.getBody());
}

@Test
void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() {
// Given
List<CapabilityType> expectedCapabilities = Arrays.asList(
CapabilityType.SCORE,
CapabilityType.CONTRIBUTION_ORIGINAL,
CapabilityType.CONTRIBUTION_TRANSFORMED);
MojoScorer scorer = mock(MojoScorer.class);
when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ALL);

ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder);

// When
ResponseEntity<List<CapabilityType>> response = controller.getCapabilities();

// Then
assertEquals(expectedCapabilities, response.getBody());
}

@Test
void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() {
// Given
List<CapabilityType> expectedCapabilities = Arrays.asList(
CapabilityType.SCORE,
CapabilityType.CONTRIBUTION_ORIGINAL);
MojoScorer scorer = mock(MojoScorer.class);
when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ORIGINAL);

ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder);

// When
ResponseEntity<List<CapabilityType>> response = controller.getCapabilities();

// Then
assertEquals(expectedCapabilities, response.getBody());
}

@Test
void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() {
// Given
List<CapabilityType> expectedCapabilities = Arrays.asList(
CapabilityType.SCORE,
CapabilityType.CONTRIBUTION_TRANSFORMED);
MojoScorer scorer = mock(MojoScorer.class);
when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED);

ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder);

// When
ResponseEntity<List<CapabilityType>> response = controller.getCapabilities();

// Then
assertEquals(expectedCapabilities, response.getBody());
}
}

0 comments on commit 8ee42a4

Please sign in to comment.