From f9d2e2806d4adb5a41efaa0c5b9305352799d5cb Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 23 Aug 2024 17:00:39 -0400 Subject: [PATCH] Move to utils, check in pre_migrate signal --- awx/main/apps.py | 10 +++++++++ awx/main/management/commands/check_db.py | 28 +++++++++--------------- awx/main/utils/db.py | 24 ++++++++++++++++++++ 3 files changed, 44 insertions(+), 18 deletions(-) diff --git a/awx/main/apps.py b/awx/main/apps.py index 099caea96a6b..d0348a21e45d 100644 --- a/awx/main/apps.py +++ b/awx/main/apps.py @@ -1,6 +1,10 @@ from django.apps import AppConfig from django.utils.translation import gettext_lazy as _ +from django.core.management.base import CommandError +from django.db.models.signals import pre_migrate + from awx.main.utils.named_url_graph import _customize_graph, generate_graph +from awx.main.utils.db import db_requirement_violations from awx.conf import register, fields @@ -8,6 +12,11 @@ class MainConfig(AppConfig): name = 'awx.main' verbose_name = _('Main') + def check_db_requirement(self, *args, **kwargs): + violations = db_requirement_violations() + if violations: + raise CommandError(violations) + def load_named_url_feature(self): models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')] generate_graph(models) @@ -38,3 +47,4 @@ def ready(self): super().ready() self.load_named_url_feature() + pre_migrate.connect(self.check_db_requirement, sender=self) diff --git a/awx/main/management/commands/check_db.py b/awx/main/management/commands/check_db.py index 02d8851f3c22..0d34340f3d20 100644 --- a/awx/main/management/commands/check_db.py +++ b/awx/main/management/commands/check_db.py @@ -1,30 +1,22 @@ # Copyright (c) 2015 Ansible, Inc. # All Rights Reserved -import sys - -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError from django.db import connection +from awx.main.utils.db import db_requirement_violations + class Command(BaseCommand): """Checks connection to the database, and prints out connection info if not connected""" def handle(self, *args, **options): - if connection.vendor == 'postgresql': - - with connection.cursor() as cursor: - cursor.execute("SELECT version()") - version = str(cursor.fetchone()[0]) + with connection.cursor() as cursor: + cursor.execute("SELECT version()") + version = str(cursor.fetchone()[0]) - # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 - # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. - # The return of connection.pg_version is something like 12013 - if (connection.pg_version // 10000) < 12: - self.stderr.write(f"At a minimum, postgres version 12 is required, found {version}\n") - sys.exit(1) + violations = db_requirement_violations() + if violations: + raise CommandError(violations) - return "Database Version: {}".format(version) - else: - self.stderr.write(f"Running server with '{connection.vendor}' type database is not supported\n") - sys.exit(1) + return "Database Version: {}".format(version) diff --git a/awx/main/utils/db.py b/awx/main/utils/db.py index 8cc6aacce9f2..8f549f80c229 100644 --- a/awx/main/utils/db.py +++ b/awx/main/utils/db.py @@ -1,10 +1,34 @@ # Copyright (c) 2017 Ansible by Red Hat # All Rights Reserved. +from typing import Optional from awx.settings.application_name import set_application_name +from awx import MODE + from django.conf import settings +from django.db import connection def set_connection_name(function): set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function) + + +MIN_PG_VERSION = 12 + + +def db_requirement_violations() -> Optional[str]: + if connection.vendor == 'postgresql': + + # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 + # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. + # The return of connection.pg_version is something like 12013 + major_version = connection.pg_version // 10000 + if major_version < MIN_PG_VERSION: + return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n" + + return None + else: + if MODE == 'production': + return f"Running server with '{connection.vendor}' type database is not supported\n" + return None