diff --git a/backend/infrahub/core/initialization.py b/backend/infrahub/core/initialization.py index ae6989a849..9d53d510a5 100644 --- a/backend/infrahub/core/initialization.py +++ b/backend/infrahub/core/initialization.py @@ -28,6 +28,7 @@ from infrahub.core.schema.manager import SchemaManager from infrahub.database import InfrahubDatabase from infrahub.exceptions import DatabaseError +from infrahub.graphql.manager import GraphQLSchemaManager from infrahub.log import get_logger from infrahub.menu.menu import default_menu from infrahub.menu.utils import create_menu_children @@ -178,6 +179,16 @@ async def initialization(db: InfrahubDatabase) -> None: ) await branch.save(db=db) + default_branch = registry.get_branch_from_registry(branch=registry.default_branch) + schema_branch = registry.schema.get_schema_branch(name=default_branch.name) + gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=default_branch, schema_branch=schema_branch) + gqlm.get_graphql_schema( + include_query=True, + include_mutation=True, + include_subscription=True, + include_types=True, + ) + # --------------------------------------------------- # Load Default Namespace # --------------------------------------------------- diff --git a/backend/infrahub/tasks/registry.py b/backend/infrahub/tasks/registry.py index 3d726a7bc7..9ed4d6b210 100644 --- a/backend/infrahub/tasks/registry.py +++ b/backend/infrahub/tasks/registry.py @@ -18,6 +18,7 @@ async def refresh_branches(db: InfrahubDatabase) -> None: If a branch is already present with a different value for the hash We pull the new schema from the database and we update the registry. """ + from infrahub.graphql.manager import GraphQLSchemaManager # pylint: disable=import-outside-toplevel,cyclic-import async with lock.registry.local_schema_lock(): branches = await registry.branch_object.get_list(db=db) @@ -38,11 +39,27 @@ async def refresh_branches(db: InfrahubDatabase) -> None: ) await registry.schema.load_schema(db=db, branch=new_branch) registry.branch[new_branch.name] = new_branch + schema_branch = registry.schema.get_schema_branch(name=new_branch.name) + gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=new_branch, schema_branch=schema_branch) + gqlm.get_graphql_schema( + include_query=True, + include_mutation=True, + include_subscription=True, + include_types=True, + ) else: log.info("New branch detected, pulling schema", branch=new_branch.name, worker=WORKER_IDENTITY) await registry.schema.load_schema(db=db, branch=new_branch) registry.branch[new_branch.name] = new_branch + schema_branch = registry.schema.get_schema_branch(name=new_branch.name) + gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=new_branch, schema_branch=schema_branch) + gqlm.get_graphql_schema( + include_query=True, + include_mutation=True, + include_subscription=True, + include_types=True, + ) for branch_name in list(registry.branch.keys()): if branch_name not in active_branches: diff --git a/backend/tests/helpers/schema/__init__.py b/backend/tests/helpers/schema/__init__.py index a066165800..f093f65ae7 100644 --- a/backend/tests/helpers/schema/__init__.py +++ b/backend/tests/helpers/schema/__init__.py @@ -4,6 +4,7 @@ from infrahub.core import registry from infrahub.core.schema import SchemaRoot +from infrahub.graphql.manager import GraphQLSchemaManager from .car import CAR from .manufacturer import MANUFACTURER @@ -28,6 +29,7 @@ async def load_schema(db: InfrahubDatabase, schema: SchemaRoot, branch_name: str await registry.schema.update_schema_branch( schema=tmp_schema, db=db, branch=branch_name or default_branch_name, update_db=True ) + GraphQLSchemaManager.clear_cache() __all__ = ["CAR", "CAR_SCHEMA", "MANUFACTURER", "PERSON", "TICKET", "WIDGET"] diff --git a/backend/tests/unit/api/test_50_config_api.py b/backend/tests/unit/api/test_50_config_api.py index a1c5b3ca0b..f9dfd6fabc 100644 --- a/backend/tests/unit/api/test_50_config_api.py +++ b/backend/tests/unit/api/test_50_config_api.py @@ -1,7 +1,9 @@ from infrahub.database import InfrahubDatabase -async def test_config_endpoint(db: InfrahubDatabase, client, client_headers, default_branch): +async def test_config_endpoint( + db: InfrahubDatabase, client, client_headers, default_branch, register_core_models_schema: None +): with client: response = client.get( "/api/config", diff --git a/backend/tests/unit/api/test_auth.py b/backend/tests/unit/api/test_auth.py index 6ed790a453..eaf57cee03 100644 --- a/backend/tests/unit/api/test_auth.py +++ b/backend/tests/unit/api/test_auth.py @@ -1,6 +1,8 @@ import jwt +from fastapi.testclient import TestClient from infrahub import config +from infrahub.core.branch import Branch from infrahub.database import InfrahubDatabase EXPIRED_ACCESS_TOKEN = ( @@ -163,7 +165,9 @@ async def test_password_based_login_invalid_password(db: InfrahubDatabase, defau } -async def test_use_expired_token(db: InfrahubDatabase, default_branch, client): +async def test_use_expired_token( + db: InfrahubDatabase, default_branch: Branch, client: TestClient, register_core_models_schema: None +) -> None: with client: response = client.get( "/api/transform/jinja2/testing", headers={"Authorization": f"Bearer {EXPIRED_ACCESS_TOKEN}"} @@ -173,7 +177,9 @@ async def test_use_expired_token(db: InfrahubDatabase, default_branch, client): assert response.json() == {"data": None, "errors": [{"message": "Expired Signature", "extensions": {"code": 401}}]} -async def test_refresh_access_token_with_expired_refresh_token(db: InfrahubDatabase, default_branch, client): +async def test_refresh_access_token_with_expired_refresh_token( + db: InfrahubDatabase, default_branch: Branch, client: TestClient, register_core_models_schema: None +) -> None: """Validate that the correct error is returned for an expired refresh token""" with client: response = client.post("/api/auth/refresh", headers={"Authorization": f"Bearer {EXPIRED_REFRESH_TOKEN}"}) diff --git a/backend/tests/unit/api/test_openapi.py b/backend/tests/unit/api/test_openapi.py index 3158315544..8b209e52b2 100644 --- a/backend/tests/unit/api/test_openapi.py +++ b/backend/tests/unit/api/test_openapi.py @@ -1,10 +1,9 @@ +from fastapi.testclient import TestClient + from infrahub.core.branch import Branch -async def test_openapi( - client, - default_branch: Branch, -): +async def test_openapi(client: TestClient, default_branch: Branch, register_core_models_schema: None) -> None: """Validate that the OpenAPI specs can be generated.""" with client: response = client.get(