From 0c42695697eb7940660efccede4422a99de999d4 Mon Sep 17 00:00:00 2001 From: charles_moulhaud Date: Wed, 12 Jun 2024 21:09:24 +0200 Subject: [PATCH] resolve #1628 [NLP-Model] Fix problems with Sagemaker model integration + add integration test for english request --- nlp/admin/server/pom.xml | 5 ++ nlp/front/ioc/pom.xml | 4 ++ nlp/model/sagemaker/pom.xml | 5 ++ .../main/kotlin/SagemakerEntityClassifier.kt | 3 +- .../SagemakerAwsClientIntegrationTest.kt | 48 +++++++++++++++++-- 5 files changed, 61 insertions(+), 4 deletions(-) diff --git a/nlp/admin/server/pom.xml b/nlp/admin/server/pom.xml index 697cf02898..fd67c549d3 100644 --- a/nlp/admin/server/pom.xml +++ b/nlp/admin/server/pom.xml @@ -67,6 +67,11 @@ org.apache.commons commons-csv + + + org.jasypt + jasypt + diff --git a/nlp/front/ioc/pom.xml b/nlp/front/ioc/pom.xml index 1a917913ac..b1aed23b8c 100644 --- a/nlp/front/ioc/pom.xml +++ b/nlp/front/ioc/pom.xml @@ -41,6 +41,10 @@ ai.tock tock-nlp-model-opennlp + + ai.tock + tock-nlp-model-sagemaker + ai.tock tock-nlp-model-service diff --git a/nlp/model/sagemaker/pom.xml b/nlp/model/sagemaker/pom.xml index 3344526367..20870139bf 100644 --- a/nlp/model/sagemaker/pom.xml +++ b/nlp/model/sagemaker/pom.xml @@ -51,6 +51,11 @@ aws-query-protocol ${aws.version} + + commons-logging + commons-logging + 1.3.2 + diff --git a/nlp/model/sagemaker/src/main/kotlin/SagemakerEntityClassifier.kt b/nlp/model/sagemaker/src/main/kotlin/SagemakerEntityClassifier.kt index 696bf08653..7d174bf226 100644 --- a/nlp/model/sagemaker/src/main/kotlin/SagemakerEntityClassifier.kt +++ b/nlp/model/sagemaker/src/main/kotlin/SagemakerEntityClassifier.kt @@ -48,7 +48,8 @@ internal class SagemakerEntityClassifier(model: EntityModelHolder) : NlpEntityCl e.start, e.end, // entity is entityType in fact -- do not modify for the moment - Entity(EntityType(e.entity),e.role.toString()) + Entity(EntityType(e.entity),e.role.toString()), + e.value ), e.confidence ) diff --git a/nlp/model/sagemaker/src/test/kotlin/SagemakerAwsClientIntegrationTest.kt b/nlp/model/sagemaker/src/test/kotlin/SagemakerAwsClientIntegrationTest.kt index d5dbf4b799..b11cdc20c0 100644 --- a/nlp/model/sagemaker/src/test/kotlin/SagemakerAwsClientIntegrationTest.kt +++ b/nlp/model/sagemaker/src/test/kotlin/SagemakerAwsClientIntegrationTest.kt @@ -38,20 +38,24 @@ import kotlin.test.assertTrue * limitations under the License. */ +/** + * Tests are disabled because it calls sagemaker endpoints on aws that can be expensive. So be careful if you want to really execute it !! + * In each test, set endpoint name and profile you want to test in SagemakerAwsClientProperties + */ class SagemakerAwsClientIntegrationTest { @Test - @Disabled // Test is disabled because it calls a sagemaker endpoint on aws that can be expensive. So be careful if you want to really execute it + @Disabled fun testParseIntents() { val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default") val client = SagemakerAwsClient(config) val response = client.parseIntent(SagemakerAwsClient.ParsedRequest("je veux un TGV Paris Marseille demain à 18h")) assertEquals(response.intent?.name, "evoyageurs:search_by_od") - assertTrue { response.intent?.score!! > 0.98 } + assertTrue { response.intent?.score!! > 0.93 } } @Test - @Disabled // Test is disabled because it calls a sagemaker endpoint on aws that can be expensive. So be careful if you want to really execute it + @Disabled fun testParseEntities() { val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default") val client = SagemakerAwsClient(config) @@ -86,4 +90,42 @@ class SagemakerAwsClientIntegrationTest { assertEquals(response.entities[3].role , "destination") assert(response.entities[3].confidence > 0.99) } + + @Test + @Disabled + fun testParseEntitiesEnglishRequest() { + val config = SagemakerAwsClientProperties(Region.EU_WEST_3, "default", "application/json", "default") + val client = SagemakerAwsClient(config) + val response = client.parseEntities(SagemakerAwsClient.ParsedRequest("Is my TGV 8536 from Cannes to Montpellier delayed?")) + println(response) + assertEquals(response.entities[0].value , "TGV") + assertEquals(response.entities[0].start , 6) + assertEquals(response.entities[0].end , 9) + assertEquals(response.entities[0].entity , "evoyageurs:mode") + assertEquals(response.entities[0].role , "mode") + assert(response.entities[0].confidence > 0.99) + + + assertEquals(response.entities[1].value , "8536") + assertEquals(response.entities[1].start , 10) + assertEquals(response.entities[1].end , 14) + assertEquals(response.entities[1].entity , "evoyageurs:train") + assertEquals(response.entities[1].role , "train") + assert(response.entities[1].confidence > 0.99) + + assertEquals(response.entities[2].value , "Cannes") + assertEquals(response.entities[2].start , 20) + assertEquals(response.entities[2].end , 26) + assertEquals(response.entities[2].entity , "evoyageurs:location") + assertEquals(response.entities[2].role , "origin") + assert(response.entities[2].confidence > 0.99) + + assertEquals(response.entities[3].value , "Montpellier") + assertEquals(response.entities[3].start , 30) + assertEquals(response.entities[3].end , 41) + assertEquals(response.entities[3].entity , "evoyageurs:location") + assertEquals(response.entities[3].role , "destination") + assert(response.entities[3].confidence > 0.99) + } + }