From 04c5b7623ab0677c289a88626163c7b69bebee9d Mon Sep 17 00:00:00 2001 From: David Ruiz Falco Date: Tue, 7 Jan 2025 10:35:20 +0100 Subject: [PATCH] ndb client --- nuclia_e2e/tests/conftest.py | 10 +++++++++- nuclia_e2e/tests/test_kb.py | 13 +++++++++++-- nuclia_e2e/tests/test_onboarding.py | 4 ++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/nuclia_e2e/tests/conftest.py b/nuclia_e2e/tests/conftest.py index 8d62cf7..148f05e 100644 --- a/nuclia_e2e/tests/conftest.py +++ b/nuclia_e2e/tests/conftest.py @@ -5,8 +5,10 @@ from nuclia.config import set_config_file from nuclia.data import get_auth from nuclia.data import get_config +from nuclia.lib.kb import AsyncNucliaDBClient +from nuclia.lib.kb import Environment from nuclia.sdk.kbs import NucliaKBS -from nuclia_e2e.tests.data import TEST_ACCOUNT_SLUG +from tests.data import TEST_ACCOUNT_SLUG import aiohttp import asyncio @@ -193,6 +195,12 @@ def regional_api_config(request, global_api_config): config.set_default_account(global_api_config["permanent_account_slug"]) config.set_default_zone(zone_config["zone_slug"]) zone_config["test_kb_slug"] = "{test_kb_slug}-{name}".format(**zone_config) + zone_config["ndb"] = AsyncNucliaDBClient( + environment=Environment.CLOUD, + url=regional_api_config["base_url"], + user_token=regional_api_config["user_token"], + region=regional_api_config["name"], + ) return zone_config diff --git a/nuclia_e2e/tests/test_kb.py b/nuclia_e2e/tests/test_kb.py index dc1578a..ff1411e 100644 --- a/nuclia_e2e/tests/test_kb.py +++ b/nuclia_e2e/tests/test_kb.py @@ -249,9 +249,10 @@ async def run_test_ask(regional_api_config): features=["keyword", "semantic", "relations"], query="why cocoa prices high?", model="chatgpt4o", + ndb=regional_api_config["ndb"], ) - assert "climate change" in ask_result.answer.decode().lower() + ask_more_result = await kb.search.ask( autofilter=True, rephrase=True, @@ -263,6 +264,7 @@ async def run_test_ask(regional_api_config): ], query="when?", model="chatgpt4o", + ndb=regional_api_config["ndb"], ) assert "earlier" in ask_more_result.answer.decode().lower() @@ -341,7 +343,14 @@ async def condition() -> tuple[bool, Any]: async def run_test_kb_deletion(regional_api_config): kbs = NucliaKBS() print("deleting " + regional_api_config["test_kb_slug"]) - await asyncio.to_thread(partial(kbs.delete, slug=regional_api_config["test_kb_slug"])) + await asyncio.to_thread( + partial( + kbs.delete, + slug=regional_api_config["test_kb_slug"], + zone=regional_api_config["name"], + account=regional_api_config["permanent_account_slug"], + ) + ) kbid = await get_kbid_from_slug(regional_api_config["test_kb_slug"]) assert kbid is None diff --git a/nuclia_e2e/tests/test_onboarding.py b/nuclia_e2e/tests/test_onboarding.py index 4ece819..bd8ecf3 100644 --- a/nuclia_e2e/tests/test_onboarding.py +++ b/nuclia_e2e/tests/test_onboarding.py @@ -1,5 +1,5 @@ -from nuclia_e2e.tests.data import TEST_ACCOUNT_SLUG -from nuclia_e2e.tests.data import TEST_ONBOARD_INQUIRY +from tests.data import TEST_ACCOUNT_SLUG +from tests.data import TEST_ONBOARD_INQUIRY import pytest