From 0c1a4bb4ac91744149a673b474f04e0b7607fd30 Mon Sep 17 00:00:00 2001 From: Mohammed Naser Date: Wed, 28 Aug 2024 19:16:07 -0400 Subject: [PATCH] Switch to using enginefacade Closes-Bug: #2067345 Change-Id: If9a2c96628cfcb819fee5e19f872ea015979b30f --- magnum/common/context.py | 2 + magnum/db/sqlalchemy/alembic/env.py | 4 +- magnum/db/sqlalchemy/api.py | 1098 ++++++++++------- magnum/db/sqlalchemy/models.py | 9 - magnum/tests/unit/db/base.py | 26 +- magnum/tests/unit/db/sqlalchemy/test_types.py | 24 +- 6 files changed, 677 insertions(+), 486 deletions(-) diff --git a/magnum/common/context.py b/magnum/common/context.py index 7d6c4011a8..a225c2194a 100644 --- a/magnum/common/context.py +++ b/magnum/common/context.py @@ -12,6 +12,7 @@ from eventlet.green import threading from oslo_context import context +from oslo_db.sqlalchemy import enginefacade from magnum.common import policy @@ -20,6 +21,7 @@ CONF = magnum.conf.CONF +@enginefacade.transaction_context_provider class RequestContext(context.RequestContext): """Extends security contexts from the OpenStack common library.""" diff --git a/magnum/db/sqlalchemy/alembic/env.py b/magnum/db/sqlalchemy/alembic/env.py index ff264b7652..e7690eee4e 100644 --- a/magnum/db/sqlalchemy/alembic/env.py +++ b/magnum/db/sqlalchemy/alembic/env.py @@ -13,8 +13,8 @@ from logging import config as log_config from alembic import context +from oslo_db.sqlalchemy import enginefacade -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import models # this is the Alembic Config object, which provides @@ -43,7 +43,7 @@ def run_migrations_online(): and associate a connection with the context. """ - engine = sqla_api.get_engine() + engine = enginefacade.writer.get_engine() with engine.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/magnum/db/sqlalchemy/api.py b/magnum/db/sqlalchemy/api.py index 0ec438063d..231803f3c8 100644 --- a/magnum/db/sqlalchemy/api.py +++ b/magnum/db/sqlalchemy/api.py @@ -14,8 +14,11 @@ """SQLAlchemy storage backend.""" +import threading + +from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc -from oslo_db.sqlalchemy import session as db_session +from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log from oslo_utils import importutils @@ -35,34 +38,13 @@ from magnum.db.sqlalchemy import models from magnum.i18n import _ -profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') +profiler_sqlalchemy = importutils.try_import("osprofiler.sqlalchemy") CONF = magnum.conf.CONF LOG = log.getLogger(__name__) -_FACADE = None - - -def _create_facade_lazily(): - global _FACADE - if _FACADE is None: - _FACADE = db_session.EngineFacade.from_config(CONF) - if profiler_sqlalchemy: - if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: - profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db") - - return _FACADE - - -def get_engine(): - facade = _create_facade_lazily() - return facade.get_engine() - - -def get_session(**kwargs): - facade = _create_facade_lazily() - return facade.get_session(**kwargs) +_CONTEXT = threading.local() def get_backend(): @@ -70,15 +52,21 @@ def get_backend(): return Connection() -def model_query(model, *args, **kwargs): - """Query helper for simpler session usage. +def _session_for_read(): + return _wrap_session(enginefacade.reader.using(_CONTEXT)) - :param session: if present, the session to use - """ - session = kwargs.get('session') or get_session() - query = session.query(model, *args) - return query +# Please add @oslo_db_api.retry_on_deadlock decorator to all methods using +# _session_for_write (as deadlocks happen on write), so that oslo_db is able +# to retry in case of deadlocks. +def _session_for_write(): + return _wrap_session(enginefacade.writer.using(_CONTEXT)) + + +def _wrap_session(session): + if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: + session = profiler_sqlalchemy.wrap_session(sa, session) + return session def add_identity_filter(query, value): @@ -99,20 +87,21 @@ def add_identity_filter(query, value): raise exception.InvalidIdentity(identity=value) -def _paginate_query(model, limit=None, marker=None, sort_key=None, - sort_dir=None, query=None): - if not query: - query = model_query(model) - sort_keys = ['id'] +def _paginate_query( + model, limit=None, marker=None, sort_key=None, sort_dir=None, query=None +): + sort_keys = ["id"] if sort_key and sort_key not in sort_keys: sort_keys.insert(0, sort_key) try: - query = db_utils.paginate_query(query, model, limit, sort_keys, - marker=marker, sort_dir=sort_dir) + query = db_utils.paginate_query( + query, model, limit, sort_keys, marker=marker, sort_dir=sort_dir + ) except db_exc.InvalidSortKey: raise exception.InvalidParameterValue( _('The sort_key value "%(key)s" is an invalid field for sorting') - % {'key': sort_key}) + % {"key": sort_key} + ) return query.all() @@ -139,7 +128,7 @@ def _add_tenant_filters(self, context, query): # reside in. This is equivalent to the project filtering above. elif context.domain_id == kst.trustee_domain_id: user_name = kst.client.users.get(context.user_id).name - user_project = user_name.split('_', 2)[1] + user_project = user_name.split("_", 2)[1] query = query.filter_by(project_id=user_project) else: query = query.filter_by(user_id=context.user_id) @@ -150,118 +139,150 @@ def _add_clusters_filters(self, query, filters): if filters is None: filters = {} - possible_filters = ["cluster_template_id", "name", "stack_id", - "api_address", "node_addresses", "project_id", - "user_id"] + possible_filters = [ + "cluster_template_id", + "name", + "stack_id", + "api_address", + "node_addresses", + "project_id", + "user_id", + ] filter_names = set(filters).intersection(possible_filters) - filter_dict = {filter_name: filters[filter_name] - for filter_name in filter_names} + filter_dict = { + filter_name: filters[filter_name] for filter_name in filter_names + } query = query.filter_by(**filter_dict) - if 'status' in filters: - query = query.filter(models.Cluster.status.in_(filters['status'])) + if "status" in filters: + query = query.filter(models.Cluster.status.in_(filters["status"])) # Helper to filter based on node_count field from nodegroups def filter_node_count(query, node_count, is_master=False): nfunc = func.sum(models.NodeGroup.node_count) - nquery = model_query(models.NodeGroup) - if is_master: - nquery = nquery.filter(models.NodeGroup.role == 'master') - else: - nquery = nquery.filter(models.NodeGroup.role != 'master') - nquery = nquery.group_by(models.NodeGroup.cluster_id) - nquery = nquery.having(nfunc == node_count) - uuids = [ng.cluster_id for ng in nquery.all()] - return query.filter(models.Cluster.uuid.in_(uuids)) - - if 'node_count' in filters: + with _session_for_read() as session: + nquery = session.query(models.NodeGroup) + if is_master: + nquery = nquery.filter(models.NodeGroup.role == "master") + else: + nquery = nquery.filter(models.NodeGroup.role != "master") + nquery = nquery.group_by(models.NodeGroup.cluster_id) + nquery = nquery.having(nfunc == node_count) + uuids = [ng.cluster_id for ng in nquery.all()] + return query.filter(models.Cluster.uuid.in_(uuids)) + + if "node_count" in filters: query = filter_node_count( - query, filters['node_count'], is_master=False) - if 'master_count' in filters: + query, filters["node_count"], is_master=False + ) + if "master_count" in filters: query = filter_node_count( - query, filters['master_count'], is_master=True) + query, filters["master_count"], is_master=True + ) return query - def get_cluster_list(self, context, filters=None, limit=None, marker=None, - sort_key=None, sort_dir=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return _paginate_query(models.Cluster, limit, marker, - sort_key, sort_dir, query) - + def get_cluster_list( + self, + context, + filters=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return _paginate_query( + models.Cluster, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def create_cluster(self, values): # ensure defaults are present for new clusters - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get("uuid"): + values["uuid"] = uuidutils.generate_uuid() cluster = models.Cluster() cluster.update(values) - try: - cluster.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterAlreadyExists(uuid=values['uuid']) - return cluster + + with _session_for_write() as session: + try: + session.add(cluster) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterAlreadyExists(uuid=values["uuid"]) + return cluster def get_cluster_by_id(self, context, cluster_id): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=cluster_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_id) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=cluster_id) + try: + return query.one() + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_id) def get_cluster_by_name(self, context, cluster_name): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(name=cluster_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple clusters exist with same name.' - ' Please use the cluster uuid instead.') - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_name) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=cluster_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + "Multiple clusters exist with same name." + " Please use the cluster uuid instead." + ) + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_name) def get_cluster_by_uuid(self, context, cluster_uuid): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=cluster_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_uuid) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=cluster_uuid) + try: + return query.one() + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_uuid) def get_cluster_stats(self, context, project_id=None): - query = model_query(models.Cluster) - node_count_col = models.NodeGroup.node_count - ncfunc = func.sum(node_count_col) - - if project_id: - query = query.filter_by(project_id=project_id) - nquery = query.session.query(ncfunc.label("nodes")).filter_by( - project_id=project_id) - else: - nquery = query.session.query(ncfunc.label("nodes")) + with _session_for_read() as session: + query = session.query(models.Cluster) + node_count_col = models.NodeGroup.node_count + ncfunc = func.sum(node_count_col) + + if project_id: + query = query.filter_by(project_id=project_id) + # TODO(tylerchristie): hmmmmm?? + nquery = query.session.query(ncfunc.label("nodes")).filter_by( + project_id=project_id + ) + else: + nquery = query.session.query(ncfunc.label("nodes")) clusters = query.count() nodes = int(nquery.one()[0]) if nquery.one()[0] else 0 return clusters, nodes def get_cluster_count_all(self, context, filters=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return query.count() + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return query.count() + @oslo_db_api.retry_on_deadlock def destroy_cluster(self, cluster_id): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: @@ -273,16 +294,16 @@ def destroy_cluster(self, cluster_id): def update_cluster(self, cluster_id, values): # NOTE(dtantsur): this can lead to very strange errors - if 'uuid' in values: + if "uuid" in values: msg = _("Cannot overwrite UUID for an existing Cluster.") raise exception.InvalidParameterValue(err=msg) return self._do_update_cluster(cluster_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster(self, cluster_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: ref = query.with_for_update().one() @@ -296,189 +317,249 @@ def _add_cluster_template_filters(self, query, filters): if filters is None: filters = {} - possible_filters = ["name", "image_id", "flavor_id", - "master_flavor_id", "keypair_id", - "external_network_id", "dns_nameserver", - "project_id", "user_id", "labels"] + possible_filters = [ + "name", + "image_id", + "flavor_id", + "master_flavor_id", + "keypair_id", + "external_network_id", + "dns_nameserver", + "project_id", + "user_id", + "labels", + ] filter_names = set(filters).intersection(possible_filters) - filter_dict = {filter_name: filters[filter_name] - for filter_name in filter_names} + filter_dict = { + filter_name: filters[filter_name] for filter_name in filter_names + } return query.filter_by(**filter_dict) - def get_cluster_template_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - query = self._add_cluster_template_filters(query, filters) - # include public (and not hidden) ClusterTemplates - public_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=False) - query = query.union(public_q) - # include hidden and public ClusterTemplate if admin - if context.is_admin: - hidden_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=True) - query = query.union(hidden_q) - - return _paginate_query(models.ClusterTemplate, limit, marker, - sort_key, sort_dir, query) - + def get_cluster_template_list( + self, + context, + filters=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + query = self._add_cluster_template_filters(query, filters) + # include public (and not hidden) ClusterTemplates + public_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=False + ) + query = query.union(public_q) + # include hidden and public ClusterTemplate if admin + if context.is_admin: + hidden_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=True + ) + query = query.union(hidden_q) + + return _paginate_query( + models.ClusterTemplate, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def create_cluster_template(self, values): # ensure defaults are present for new ClusterTemplates - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get("uuid"): + values["uuid"] = uuidutils.generate_uuid() cluster_template = models.ClusterTemplate() cluster_template.update(values) - try: - cluster_template.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterTemplateAlreadyExists(uuid=values['uuid']) - return cluster_template + + with _session_for_write() as session: + try: + session.add(cluster_template) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterTemplateAlreadyExists( + uuid=values["uuid"] + ) + return cluster_template def get_cluster_template_by_id(self, context, cluster_template_id): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter(models.ClusterTemplate.id == cluster_template_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_id) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True + ) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.id == cluster_template_id + ) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_id + ) def get_cluster_template_by_uuid(self, context, cluster_template_uuid): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( - models.ClusterTemplate.uuid == cluster_template_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_uuid) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True + ) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.uuid == cluster_template_uuid + ) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_uuid + ) def get_cluster_template_by_name(self, context, cluster_template_name): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( - models.ClusterTemplate.name == cluster_template_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple ClusterTemplates exist with' - ' same name. Please use the ' - 'ClusterTemplate uuid instead.') - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_name) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True + ) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.name == cluster_template_name + ) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + "Multiple ClusterTemplates exist with" + " same name. Please use the " + "ClusterTemplate uuid instead." + ) + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_name + ) def _is_cluster_template_referenced(self, session, cluster_template_uuid): """Checks whether the ClusterTemplate is referenced by cluster(s).""" - query = model_query(models.Cluster, session=session) - query = self._add_clusters_filters(query, {'cluster_template_id': - cluster_template_uuid}) + query = session.query(models.Cluster) + query = self._add_clusters_filters( + query, {"cluster_template_id": cluster_template_uuid} + ) return query.count() != 0 def _is_publishing_cluster_template(self, values): - if (len(values) == 1 and ( - ('public' in values and values['public'] is True) or - ('hidden' in values) or - ('tags' in values and values['tags'] is not None))): + if len(values) == 1 and ( + ("public" in values and values["public"] is True) + or ("hidden" in values) + or ("tags" in values and values["tags"] is not None) + ): return True return False + @oslo_db_api.retry_on_deadlock def destroy_cluster_template(self, cluster_template_id): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: cluster_template_ref = query.one() except NoResultFound: raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_id) + clustertemplate=cluster_template_id + ) if self._is_cluster_template_referenced( - session, cluster_template_ref['uuid']): + session, cluster_template_ref["uuid"] + ): raise exception.ClusterTemplateReferenced( - clustertemplate=cluster_template_id) + clustertemplate=cluster_template_id + ) query.delete() def update_cluster_template(self, cluster_template_id, values): # NOTE(dtantsur): this can lead to very strange errors - if 'uuid' in values: + if "uuid" in values: msg = _("Cannot overwrite UUID for an existing ClusterTemplate.") raise exception.InvalidParameterValue(err=msg) return self._do_update_cluster_template(cluster_template_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster_template(self, cluster_template_id, values): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: ref = query.with_for_update().one() except NoResultFound: raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_id) + clustertemplate=cluster_template_id + ) - if self._is_cluster_template_referenced(session, ref['uuid']): + if self._is_cluster_template_referenced(session, ref["uuid"]): # NOTE(flwang): We only allow to update ClusterTemplate to be # public, hidden and rename - if (not self._is_publishing_cluster_template(values) and - list(values.keys()) != ["name"]): + if not self._is_publishing_cluster_template(values) and list( + values.keys() + ) != ["name"]: raise exception.ClusterTemplateReferenced( - clustertemplate=cluster_template_id) + clustertemplate=cluster_template_id + ) ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def create_x509keypair(self, values): # ensure defaults are present for new x509keypairs - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get("uuid"): + values["uuid"] = uuidutils.generate_uuid() x509keypair = models.X509KeyPair() x509keypair.update(values) - try: - x509keypair.save() - except db_exc.DBDuplicateEntry: - raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) - return x509keypair + + with _session_for_write() as session: + try: + session.add(x509keypair) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.X509KeyPairAlreadyExists(uuid=values["uuid"]) + return x509keypair def get_x509keypair_by_id(self, context, x509keypair_id): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=x509keypair_id) - try: - return query.one() - except NoResultFound: - raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=x509keypair_id) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) def get_x509keypair_by_uuid(self, context, x509keypair_uuid): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=x509keypair_uuid) - try: - return query.one() - except NoResultFound: - raise exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=x509keypair_uuid) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound( + x509keypair=x509keypair_uuid + ) + @oslo_db_api.retry_on_deadlock def destroy_x509keypair(self, x509keypair_id): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) count = query.delete() if count != 1: @@ -486,16 +567,16 @@ def destroy_x509keypair(self, x509keypair_id): def update_x509keypair(self, x509keypair_id, values): # NOTE(dtantsur): this can lead to very strange errors - if 'uuid' in values: + if "uuid" in values: msg = _("Cannot overwrite UUID for an existing X509KeyPair.") raise exception.InvalidParameterValue(err=msg) return self._do_update_x509keypair(x509keypair_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_x509keypair(self, x509keypair_id, values): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) try: ref = query.with_for_update().one() @@ -509,92 +590,124 @@ def _add_x509keypairs_filters(self, query, filters): if filters is None: filters = {} - if 'project_id' in filters: - query = query.filter_by(project_id=filters['project_id']) - if 'user_id' in filters: - query = query.filter_by(user_id=filters['user_id']) + if "project_id" in filters: + query = query.filter_by(project_id=filters["project_id"]) + if "user_id" in filters: + query = query.filter_by(user_id=filters["user_id"]) return query - def get_x509keypair_list(self, context, filters=None, limit=None, - marker=None, sort_key=None, sort_dir=None): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = self._add_x509keypairs_filters(query, filters) - return _paginate_query(models.X509KeyPair, limit, marker, - sort_key, sort_dir, query) - + def get_x509keypair_list( + self, + context, + filters=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = self._add_x509keypairs_filters(query, filters) + return _paginate_query( + models.X509KeyPair, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def destroy_magnum_service(self, magnum_service_id): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) count = query.delete() if count != 1: raise exception.MagnumServiceNotFound( - magnum_service_id=magnum_service_id) + magnum_service_id=magnum_service_id + ) + @oslo_db_api.retry_on_deadlock def update_magnum_service(self, magnum_service_id, values): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) try: ref = query.with_for_update().one() except NoResultFound: raise exception.MagnumServiceNotFound( - magnum_service_id=magnum_service_id) + magnum_service_id=magnum_service_id + ) - if 'report_count' in values: - if values['report_count'] > ref.report_count: + if "report_count" in values: + if values["report_count"] > ref.report_count: ref.last_seen_up = timeutils.utcnow() ref.update(values) return ref def get_magnum_service_by_host_and_binary(self, host, binary): - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - try: - return query.one() - except NoResultFound: - return None + with _session_for_read() as session: + query = session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + try: + return query.one() + except NoResultFound: + return None + @oslo_db_api.retry_on_deadlock def create_magnum_service(self, values): magnum_service = models.MagnumService() magnum_service.update(values) - try: - magnum_service.save() - except db_exc.DBDuplicateEntry: - host = values["host"] - binary = values["binary"] - LOG.warning("Magnum service with same host:%(host)s and" - " binary:%(binary)s had been saved into DB", - {'host': host, 'binary': binary}) - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - return query.one() - return magnum_service - - def get_magnum_service_list(self, disabled=None, limit=None, - marker=None, sort_key=None, sort_dir=None - ): - query = model_query(models.MagnumService) - if disabled: - query = query.filter_by(disabled=disabled) - - return _paginate_query(models.MagnumService, limit, marker, - sort_key, sort_dir, query) + with _session_for_write() as session: + try: + session.add(magnum_service) + session.flush() + except db_exc.DBDuplicateEntry: + host = values["host"] + binary = values["binary"] + LOG.warning( + "Magnum service with same host:%(host)s and" + " binary:%(binary)s had been saved into DB", + {"host": host, "binary": binary}, + ) + with _session_for_read() as read_session: + query = read_session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + return query.one() + return magnum_service + + def get_magnum_service_list( + self, + disabled=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.MagnumService) + if disabled: + query = query.filter_by(disabled=disabled) + + return _paginate_query( + models.MagnumService, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def create_quota(self, values): quotas = models.Quota() quotas.update(values) - try: - quotas.save() - except db_exc.DBDuplicateEntry: - raise exception.QuotaAlreadyExists(project_id=values['project_id'], - resource=values['resource']) - return quotas + + with _session_for_write() as session: + try: + session.add(quotas) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.QuotaAlreadyExists( + project_id=values["project_id"], + resource=values["resource"], + ) + return quotas def _add_quota_filters(self, query, filters): if filters is None: @@ -603,85 +716,113 @@ def _add_quota_filters(self, query, filters): possible_filters = ["resource", "project_id"] filter_names = set(filters).intersection(possible_filters) - filter_dict = {filter_name: filters[filter_name] - for filter_name in filter_names} + filter_dict = { + filter_name: filters[filter_name] for filter_name in filter_names + } query = query.filter_by(**filter_dict) return query - def get_quota_list(self, context, filters=None, limit=None, marker=None, - sort_key=None, sort_dir=None): - query = model_query(models.Quota) - query = self._add_quota_filters(query, filters) - return _paginate_query(models.Quota, limit, marker, - sort_key, sort_dir, query) - + def get_quota_list( + self, + context, + filters=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.Quota) + query = self._add_quota_filters(query, filters) + return _paginate_query( + models.Quota, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def update_quota(self, project_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) - resource = values['resource'] + with _session_for_write() as session: + query = session.query(models.Quota) + resource = values["resource"] try: query = query.filter_by(project_id=project_id).filter_by( - resource=resource) + resource=resource + ) ref = query.with_for_update().one() except NoResultFound: - msg = (_('project_id %(project_id)s resource %(resource)s.') % - {'project_id': project_id, 'resource': resource}) + msg = _("project_id %(project_id)s resource %(resource)s.") % { + "project_id": project_id, + "resource": resource, + } raise exception.QuotaNotFound(msg=msg) ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def delete_quota(self, project_id, resource): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) \ - .filter_by(project_id=project_id) \ + with _session_for_write() as session: + query = ( + session.query(models.Quota) + .filter_by(project_id=project_id) .filter_by(resource=resource) + ) try: query.one() except NoResultFound: - msg = (_('project_id %(project_id)s resource %(resource)s.') % - {'project_id': project_id, 'resource': resource}) + msg = _("project_id %(project_id)s resource %(resource)s.") % { + "project_id": project_id, + "resource": resource, + } raise exception.QuotaNotFound(msg=msg) query.delete() def get_quota_by_id(self, context, quota_id): - query = model_query(models.Quota) - query = query.filter_by(id=quota_id) - try: - return query.one() - except NoResultFound: - msg = _('quota id %s .') % quota_id - raise exception.QuotaNotFound(msg=msg) + with _session_for_read() as session: + query = session.query(models.Quota) + query = query.filter_by(id=quota_id) + try: + return query.one() + except NoResultFound: + msg = _("quota id %s .") % quota_id + raise exception.QuotaNotFound(msg=msg) def quota_get_all_by_project_id(self, project_id): - query = model_query(models.Quota) - result = query.filter_by(project_id=project_id).all() + with _session_for_read() as session: + query = session.query(models.Quota) + result = query.filter_by(project_id=project_id).all() return result def get_quota_by_project_id_resource(self, project_id, resource): - query = model_query(models.Quota) - query = query.filter_by(project_id=project_id).filter_by( - resource=resource) + with _session_for_read() as session: + query = session.query(models.Quota) + query = query.filter_by(project_id=project_id).filter_by( + resource=resource + ) - try: - return query.one() - except NoResultFound: - msg = (_('project_id %(project_id)s resource %(resource)s.') % - {'project_id': project_id, 'resource': resource}) - raise exception.QuotaNotFound(msg=msg) + try: + return query.one() + except NoResultFound: + msg = _("project_id %(project_id)s resource %(resource)s.") % { + "project_id": project_id, + "resource": resource, + } + raise exception.QuotaNotFound(msg=msg) def _add_federation_filters(self, query, filters): if filters is None: filters = {} - possible_filters = ["name", "project_id", "hostcluster_id", - "member_ids", "properties"] + possible_filters = [ + "name", + "project_id", + "hostcluster_id", + "member_ids", + "properties", + ] # TODO(clenimar): implement 'member_ids' filter as a contains query, # so we return all the federations that have the given clusters, @@ -689,72 +830,92 @@ def _add_federation_filters(self, query, filters): # clusters. filter_names = set(filters).intersection(possible_filters) - filter_dict = {filter_name: filters[filter_name] - for filter_name in filter_names} + filter_dict = { + filter_name: filters[filter_name] for filter_name in filter_names + } query = query.filter_by(**filter_dict) - if 'status' in filters: + if "status" in filters: query = query.filter( - models.Federation.status.in_(filters['status'])) + models.Federation.status.in_(filters["status"]) + ) return query def get_federation_by_id(self, context, federation_id): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=federation_id) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_id) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=federation_id) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_id) def get_federation_by_uuid(self, context, federation_uuid): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=federation_uuid) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_uuid) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=federation_uuid) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_uuid) def get_federation_by_name(self, context, federation_name): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(name=federation_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple federations exist with same ' - 'name. Please use the federation uuid ' - 'instead.') - except NoResultFound: - raise exception.FederationNotFound(federation=federation_name) - - def get_federation_list(self, context, limit=None, marker=None, - sort_key=None, sort_dir=None, filters=None): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = self._add_federation_filters(query, filters) - return _paginate_query(models.Federation, limit, marker, - sort_key, sort_dir, query) - + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=federation_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + "Multiple federations exist with same " + "name. Please use the federation uuid " + "instead." + ) + except NoResultFound: + raise exception.FederationNotFound(federation=federation_name) + + def get_federation_list( + self, + context, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + filters=None, + ): + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = self._add_federation_filters(query, filters) + return _paginate_query( + models.Federation, limit, marker, sort_key, sort_dir, query + ) + + @oslo_db_api.retry_on_deadlock def create_federation(self, values): - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get("uuid"): + values["uuid"] = uuidutils.generate_uuid() federation = models.Federation() federation.update(values) - try: - federation.save() - except db_exc.DBDuplicateEntry: - raise exception.FederationAlreadyExists(uuid=values['uuid']) - return federation + with _session_for_write() as session: + try: + session.add(federation) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.FederationAlreadyExists(uuid=values["uuid"]) + return federation + + @oslo_db_api.retry_on_deadlock def destroy_federation(self, federation_id): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: @@ -765,16 +926,16 @@ def destroy_federation(self, federation_id): query.delete() def update_federation(self, federation_id, values): - if 'uuid' in values: + if "uuid" in values: msg = _("Cannot overwrite UUID for an existing Federation.") raise exception.InvalidParameterValue(err=msg) return self._do_update_federation(federation_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_federation(self, federation_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: ref = query.with_for_update().one() @@ -789,38 +950,50 @@ def _add_nodegoup_filters(self, query, filters): if filters is None: filters = {} - possible_filters = ["name", "node_count", "node_addresses", - "role", "is_default"] + possible_filters = [ + "name", + "node_count", + "node_addresses", + "role", + "is_default", + ] filter_names = set(filters).intersection(possible_filters) - filter_dict = {filter_name: filters[filter_name] - for filter_name in filter_names} + filter_dict = { + filter_name: filters[filter_name] for filter_name in filter_names + } query = query.filter_by(**filter_dict) - if 'status' in filters: + if "status" in filters: query = query.filter( - models.NodeGroup.status.in_(filters['status'])) + models.NodeGroup.status.in_(filters["status"]) + ) return query + @oslo_db_api.retry_on_deadlock def create_nodegroup(self, values): - if not values.get('uuid'): - values['uuid'] = uuidutils.generate_uuid() + if not values.get("uuid"): + values["uuid"] = uuidutils.generate_uuid() nodegroup = models.NodeGroup() nodegroup.update(values) - try: - nodegroup.save() - except db_exc.DBDuplicateEntry: - raise exception.NodeGroupAlreadyExists( - cluster_id=values['cluster_id'], name=values['name']) - return nodegroup + with _session_for_write() as session: + try: + session.add(nodegroup) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.NodeGroupAlreadyExists( + cluster_id=values["cluster_id"], name=values["name"] + ) + return nodegroup + + @oslo_db_api.retry_on_deadlock def destroy_nodegroup(self, cluster_id, nodegroup_id): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -832,10 +1005,10 @@ def destroy_nodegroup(self, cluster_id, nodegroup_id): def update_nodegroup(self, cluster_id, nodegroup_id, values): return self._do_update_nodegroup(cluster_id, nodegroup_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -847,56 +1020,71 @@ def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): return ref def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(id=nodegroup_id) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(id=nodegroup_id) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(uuid=nodegroup_uuid) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(uuid=nodegroup_uuid) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(name=nodegroup_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple nodegroups exist with same ' - 'name. Please use the nodegroup uuid ' - 'instead.') - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) - - def list_cluster_nodegroups(self, context, cluster_id, filters=None, - limit=None, marker=None, sort_key=None, - sort_dir=None): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = self._add_nodegoup_filters(query, filters) - return _paginate_query(models.NodeGroup, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(name=nodegroup_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + "Multiple nodegroups exist with same " + "name. Please use the nodegroup uuid " + "instead." + ) + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) + + def list_cluster_nodegroups( + self, + context, + cluster_id, + filters=None, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, + ): + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = self._add_nodegoup_filters(query, filters) + return _paginate_query( + models.NodeGroup, limit, marker, sort_key, sort_dir, query + ) def get_cluster_nodegroup_count(self, context, cluster_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - return query.count() + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + return query.count() diff --git a/magnum/db/sqlalchemy/models.py b/magnum/db/sqlalchemy/models.py index 92b474da37..2d83093010 100644 --- a/magnum/db/sqlalchemy/models.py +++ b/magnum/db/sqlalchemy/models.py @@ -87,15 +87,6 @@ def as_dict(self): d[c.name] = self[c.name] return d - def save(self, session=None): - import magnum.db.sqlalchemy.api as db_api - - if session is None: - session = db_api.get_session() - - with session.begin(): - super(MagnumBase, self).save(session) - Base = declarative_base(cls=MagnumBase) diff --git a/magnum/tests/unit/db/base.py b/magnum/tests/unit/db/base.py index 711d30caeb..d78d8fa378 100644 --- a/magnum/tests/unit/db/base.py +++ b/magnum/tests/unit/db/base.py @@ -16,10 +16,10 @@ """Magnum DB test base class.""" import fixtures +from oslo_db.sqlalchemy import enginefacade import magnum.conf from magnum.db import api as dbapi -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import migration from magnum.db.sqlalchemy import models from magnum.tests import base @@ -32,16 +32,15 @@ class Database(fixtures.Fixture): - def __init__(self, db_api, db_migrate, sql_connection): + def __init__(self, engine, db_migrate, sql_connection): self.sql_connection = sql_connection - self.engine = db_api.get_engine() + self.engine = engine self.engine.dispose() - conn = self.engine.connect() - self.setup_sqlite(db_migrate) - self.post_migrations() - - self._DB = "".join(line for line in conn.connection.iterdump()) + with self.engine.connect() as conn: + self.setup_sqlite(db_migrate) + self.post_migrations() + self._DB = "".join(line for line in conn.connection.iterdump()) self.engine.dispose() def setup_sqlite(self, db_migrate): @@ -50,9 +49,10 @@ def setup_sqlite(self, db_migrate): models.Base.metadata.create_all(self.engine) db_migrate.stamp('head') - def _setUp(self): - conn = self.engine.connect() - conn.connection.executescript(self._DB) + def setUp(self): + super(Database, self).setUp() + with self.engine.connect() as conn: + conn.connection.executescript(self._DB) self.addCleanup(self.engine.dispose) def post_migrations(self): @@ -68,6 +68,8 @@ def setUp(self): global _DB_CACHE if not _DB_CACHE: - _DB_CACHE = Database(sqla_api, migration, + engine = enginefacade.writer.get_engine() + _DB_CACHE = Database(engine, migration, sql_connection=CONF.database.connection) + engine.dispose() self.useFixture(_DB_CACHE) diff --git a/magnum/tests/unit/db/sqlalchemy/test_types.py b/magnum/tests/unit/db/sqlalchemy/test_types.py index b9a2c1103a..d89bea97ca 100644 --- a/magnum/tests/unit/db/sqlalchemy/test_types.py +++ b/magnum/tests/unit/db/sqlalchemy/test_types.py @@ -26,16 +26,22 @@ def test_JSONEncodedDict_default_value(self): # Create ClusterTemplate w/o labels cluster_template1_id = uuidutils.generate_uuid() self.dbapi.create_cluster_template({'uuid': cluster_template1_id}) - cluster_template1 = sa_api.model_query( - models.ClusterTemplate).filter_by(uuid=cluster_template1_id).one() + with sa_api._session_for_read() as session: + cluster_template1 = (session.query( + models.ClusterTemplate) + .filter_by(uuid=cluster_template1_id) + .one()) self.assertEqual({}, cluster_template1.labels) # Create ClusterTemplate with labels cluster_template2_id = uuidutils.generate_uuid() self.dbapi.create_cluster_template( {'uuid': cluster_template2_id, 'labels': {'bar': 'foo'}}) - cluster_template2 = sa_api.model_query( - models.ClusterTemplate).filter_by(uuid=cluster_template2_id).one() + with sa_api._session_for_read() as session: + cluster_template2 = (session.query( + models.ClusterTemplate) + .filter_by(uuid=cluster_template2_id) + .one()) self.assertEqual('foo', cluster_template2.labels['bar']) def test_JSONEncodedDict_type_check(self): @@ -48,8 +54,9 @@ def test_JSONEncodedList_default_value(self): # Create nodegroup w/o node_addresses nodegroup1_id = uuidutils.generate_uuid() self.dbapi.create_nodegroup({'uuid': nodegroup1_id}) - nodegroup1 = sa_api.model_query( - models.NodeGroup).filter_by(uuid=nodegroup1_id).one() + with sa_api._session_for_read() as session: + nodegroup1 = session.query( + models.NodeGroup).filter_by(uuid=nodegroup1_id).one() self.assertEqual([], nodegroup1.node_addresses) # Create nodegroup with node_addresses @@ -59,8 +66,9 @@ def test_JSONEncodedList_default_value(self): 'node_addresses': ['mynode_address1', 'mynode_address2'] }) - nodegroup2 = sa_api.model_query( - models.NodeGroup).filter_by(uuid=nodegroup2_id).one() + with sa_api._session_for_read() as session: + nodegroup2 = session.query( + models.NodeGroup).filter_by(uuid=nodegroup2_id).one() self.assertEqual(['mynode_address1', 'mynode_address2'], nodegroup2.node_addresses)