Skip to content

Commit

Permalink
Add commit / rollback SQLA on commands which edit the database on hard disk. See #1187. Originally, executions only occurred in memory and weren't saved to hard disk.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 691080052
  • Loading branch information
xingyousong authored and copybara-github committed Oct 29, 2024
1 parent f083e4e commit 64b0220
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
2 changes: 1 addition & 1 deletion vizier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@

sys.path.append(PROTO_ROOT)

__version__ = "0.1.19"
__version__ = "0.1.20"
20 changes: 19 additions & 1 deletion vizier/_src/service/sql_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
class SQLDataStore(datastore.DataStore):
"""SQL Datastore."""

def __init__(self, engine):
def __init__(self, engine: sqla.engine.Engine):
self._engine = engine
self._connection = self._engine.connect()
self._root_metadata = sqla.MetaData()
Expand Down Expand Up @@ -104,12 +104,16 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource:
with self._lock:
try:
self._connection.execute(owner_query)
self._connection.commit()
except sqla.exc.IntegrityError:
logging.info('Owner with name %s currently exists.', owner_name)
self._connection.rollback()
try:
self._connection.execute(study_query)
self._connection.commit()
return study_resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Study with name %s already exists.' % study.name
) from e
Expand Down Expand Up @@ -148,6 +152,7 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Study %s does not exist.' % study.name)
self._connection.execute(uq)
self._connection.commit()
return study_resource

def delete_study(self, study_name: str) -> None:
Expand All @@ -172,6 +177,7 @@ def delete_study(self, study_name: str) -> None:
raise NotFoundError('Study %s does not exist.' % study_name)
self._connection.execute(dsq)
self._connection.execute(dtq)
self._connection.commit()

def list_studies(self, owner_name: str) -> List[study_pb2.Study]:
owner_id = resources.OwnerResource.from_name(owner_name).owner_id
Expand Down Expand Up @@ -205,8 +211,10 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
with self._lock:
try:
self._connection.execute(query)
self._connection.commit()
return trial_resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Trial with name %s already exists.' % trial.name
) from e
Expand Down Expand Up @@ -246,6 +254,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Trial %s does not exist.' % trial.name)
self._connection.execute(uq)
self._connection.commit()

return trial_resource

Expand Down Expand Up @@ -283,6 +292,7 @@ def delete_trial(self, trial_name: str) -> None:
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Trial %s does not exist.' % trial_name)
self._connection.execute(dq)
self._connection.commit()

def max_trial_id(self, study_name: str) -> int:
study_resource = resources.StudyResource.from_name(study_name)
Expand Down Expand Up @@ -323,8 +333,10 @@ def create_suggestion_operation(
try:
with self._lock:
self._connection.execute(query)
self._connection.commit()
return resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Suggest Op with name %s already exists.' % operation.name
) from e
Expand Down Expand Up @@ -375,6 +387,7 @@ def update_suggestion_operation(
if not self._connection.execute(eq).fetchone()[0]:
raise NotFoundError('Suggest op %s does not exist.' % operation.name)
self._connection.execute(uq)
self._connection.commit()
return resource

def list_suggestion_operations(
Expand Down Expand Up @@ -464,8 +477,10 @@ def create_early_stopping_operation(
try:
with self._lock:
self._connection.execute(query)
self._connection.commit()
return resource
except sqla.exc.IntegrityError as e:
self._connection.rollback()
raise AlreadyExistsError(
'Early stopping op with name %s already exists.' % operation.name
) from e
Expand Down Expand Up @@ -521,6 +536,7 @@ def update_early_stopping_operation(
'Early stopping op %s does not exist.' % operation.name
)
self._connection.execute(uq)
self._connection.commit()
return resource

def update_metadata(
Expand Down Expand Up @@ -552,6 +568,7 @@ def update_metadata(
usq = usq.where(self._studies_table.c.study_name == study_name)
usq = usq.values(serialized_study=original_study.SerializeToString())
self._connection.execute(usq)
self._connection.commit()

# Split the trial-related metadata by Trial.
split_metadata = collections.defaultdict(list)
Expand All @@ -578,3 +595,4 @@ def update_metadata(
utq = utq.where(self._trials_table.c.trial_name == trial_name)
utq = utq.values(serialized_trial=original_trial.SerializeToString())
self._connection.execute(utq)
self._connection.commit()
10 changes: 6 additions & 4 deletions vizier/_src/service/sql_datastore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from __future__ import annotations

"""Tests for sql_datastore."""

import os
import sqlalchemy as sqla

from vizier._src.service import constants
from vizier._src.service import datastore_test_lib
from vizier._src.service import sql_datastore
Expand Down Expand Up @@ -46,7 +46,9 @@ def setUp(self):
)
)

engine = sqla.create_engine(constants.SQL_MEMORY_URL, echo=True)
engine = sqla.create_engine(
constants.SQL_MEMORY_URL, echo=True, future=True
)
self.datastore = sql_datastore.SQLDataStore(engine)
super().setUp()

Expand Down Expand Up @@ -92,12 +94,12 @@ def test_local_hdd_persistence(self):
db_path = os.path.join(absltest.get_default_test_tmpdir(), 'local.db')
sql_url = f'sqlite:///{db_path}'

engine = sqla.create_engine(sql_url, echo=True)
engine = sqla.create_engine(sql_url, echo=True, future=True)
datastore = sql_datastore.SQLDataStore(engine)
datastore.create_study(self.example_study)
del datastore

engine2 = sqla.create_engine(sql_url, echo=True)
engine2 = sqla.create_engine(sql_url, echo=True, future=True)
datastore2 = sql_datastore.SQLDataStore(engine2)
study = datastore2.load_study(self.example_study.name)

Expand Down
3 changes: 2 additions & 1 deletion vizier/_src/service/vizier_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ def __init__(
else:
engine = sqla.create_engine(
database_url,
echo=False, # Set True to log transactions for debugging.
connect_args={'check_same_thread': False},
echo=False, # Set True to log transactions for debugging.
future=True, # Backward compatibility with sqlalchemy 1.4.
poolclass=sqla.pool.StaticPool,
)
self.datastore = sql_datastore.SQLDataStore(engine)
Expand Down

0 comments on commit 64b0220

Please sign in to comment.