Skip to content

Commit

Permalink
resolve #1628 [NLP-Model] Fix problems with Sagemaker model integrati…
Browse files Browse the repository at this point in the history
…on + add integration test for english request
  • Loading branch information
charles_moulhaud authored and vsct-jburet committed Jun 13, 2024
1 parent 08b2b65 commit 0c42695
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 4 deletions.
5 changes: 5 additions & 0 deletions nlp/admin/server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
</dependency>

<dependency>
<groupId>org.jasypt</groupId>
<artifactId>jasypt</artifactId>
</dependency>
</dependencies>

</project>
4 changes: 4 additions & 0 deletions nlp/front/ioc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
<groupId>ai.tock</groupId>
<artifactId>tock-nlp-model-opennlp</artifactId>
</dependency>
<dependency>
<groupId>ai.tock</groupId>
<artifactId>tock-nlp-model-sagemaker</artifactId>
</dependency>
<dependency>
<groupId>ai.tock</groupId>
<artifactId>tock-nlp-model-service</artifactId>
Expand Down
5 changes: 5 additions & 0 deletions nlp/model/sagemaker/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
<artifactId>aws-query-protocol</artifactId>
<version>${aws.version}</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.3.2</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ internal class SagemakerEntityClassifier(model: EntityModelHolder) : NlpEntityCl
e.start,
e.end,
// entity is entityType in fact -- do not modify for the moment
Entity(EntityType(e.entity),e.role.toString())
Entity(EntityType(e.entity),e.role.toString()),
e.value
),
e.confidence
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,24 @@ import kotlin.test.assertTrue
* limitations under the License.
*/

/**
* Tests are disabled because it calls sagemaker endpoints on aws that can be expensive. So be careful if you want to really execute it !!
* In each test, set endpoint name and profile you want to test in SagemakerAwsClientProperties
*/
class SagemakerAwsClientIntegrationTest {

@Test
@Disabled // Test is disabled because it calls a sagemaker endpoint on aws that can be expensive. So be careful if you want to really execute it
@Disabled
fun testParseIntents() {
val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default")
val client = SagemakerAwsClient(config)
val response = client.parseIntent(SagemakerAwsClient.ParsedRequest("je veux un TGV Paris Marseille demain à 18h"))
assertEquals(response.intent?.name, "evoyageurs:search_by_od")
assertTrue { response.intent?.score!! > 0.98 }
assertTrue { response.intent?.score!! > 0.93 }
}

@Test
@Disabled // Test is disabled because it calls a sagemaker endpoint on aws that can be expensive. So be careful if you want to really execute it
@Disabled
fun testParseEntities() {
val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default")
val client = SagemakerAwsClient(config)
Expand Down Expand Up @@ -86,4 +90,42 @@ class SagemakerAwsClientIntegrationTest {
assertEquals(response.entities[3].role , "destination")
assert(response.entities[3].confidence > 0.99)
}

@Test
@Disabled
fun testParseEntitiesEnglishRequest() {
val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default")
val client = SagemakerAwsClient(config)
val response = client.parseEntities(SagemakerAwsClient.ParsedRequest("Is my TGV 8536 from Cannes to Montpellier delayed?"))
println(response)
assertEquals(response.entities[0].value , "TGV")
assertEquals(response.entities[0].start , 6)
assertEquals(response.entities[0].end , 9)
assertEquals(response.entities[0].entity , "evoyageurs:mode")
assertEquals(response.entities[0].role , "mode")
assert(response.entities[0].confidence > 0.99)


assertEquals(response.entities[1].value , "8536")
assertEquals(response.entities[1].start , 10)
assertEquals(response.entities[1].end , 14)
assertEquals(response.entities[1].entity , "evoyageurs:train")
assertEquals(response.entities[1].role , "train")
assert(response.entities[1].confidence > 0.99)

assertEquals(response.entities[2].value , "Cannes")
assertEquals(response.entities[2].start , 20)
assertEquals(response.entities[2].end , 26)
assertEquals(response.entities[2].entity , "evoyageurs:location")
assertEquals(response.entities[2].role , "origin")
assert(response.entities[2].confidence > 0.99)

assertEquals(response.entities[3].value , "Montpellier")
assertEquals(response.entities[3].start , 30)
assertEquals(response.entities[3].end , 41)
assertEquals(response.entities[3].entity , "evoyageurs:location")
assertEquals(response.entities[3].role , "destination")
assert(response.entities[3].confidence > 0.99)
}

}

0 comments on commit 0c42695

Please sign in to comment.