Skip to content

Commit

Permalink
merging maria-migrate into main
Browse files Browse the repository at this point in the history
  • Loading branch information
amyfromandi committed Sep 4, 2024
2 parents 28b40ea + 57a2ccf commit 240d0f6
Show file tree
Hide file tree
Showing 146 changed files with 1,788 additions and 998 deletions.
19 changes: 19 additions & 0 deletions .idea/dataSources.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 69 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion cli/macrostrat/cli/_dev/dump_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from macrostrat.utils import get_logger
from sqlalchemy.engine import Engine

from .utils import _create_command, print_stdout, print_stream_progress
from .utils import _create_command
from .stream_utils import print_stream_progress, print_stdout

log = get_logger(__name__)

Expand Down
12 changes: 4 additions & 8 deletions cli/macrostrat/cli/_dev/restore_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from .utils import (
_create_command,
_create_database_if_not_exists,
print_stdout,
print_stream_progress,
)
from .stream_utils import print_stream_progress, print_stdout

console = Console()

Expand All @@ -26,10 +25,8 @@ def pg_restore(*args, **kwargs):

async def _pg_restore(
engine: Engine,
*,
*args,
create=False,
command_prefix: Optional[list] = None,
args: list = [],
postgres_container: str = "postgres:15",
):
# Pipe file to pg_restore, mimicking
Expand All @@ -42,11 +39,10 @@ async def _pg_restore(
# host, if possible, is probably the fastest option. There should be
# multiple options ideally.
_cmd = _create_command(
engine,
"pg_restore",
"-d",
args=args,
prefix=command_prefix,
engine,
*args,
container=postgres_container,
)

Expand Down
117 changes: 117 additions & 0 deletions cli/macrostrat/cli/_dev/stream_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import asyncio
import sys
import zlib

from aiofiles.threadpool import AsyncBufferedIOBase
from macrostrat.utils import get_logger
from .utils import console

log = get_logger(__name__)


async def print_stream_progress(
    input: asyncio.StreamReader | asyncio.subprocess.Process,
    out_stream: asyncio.StreamWriter | None,
    *,
    verbose: bool = False,
    chunk_size: int = 1024,
    prefix: str | None = None,
):
    """Copy *input* to *out_stream* chunk by chunk, printing a running
    progress total (in MB) to stderr.

    :param input: source stream; a ``Process`` is read via its ``stdout``.
    :param out_stream: destination — an aiofiles async file, an
        ``asyncio.StreamWriter``, or ``None`` to discard the data.
    :param verbose: if True, log every chunk that is read.
    :param chunk_size: number of bytes to read per iteration.
    :param prefix: label for the progress line (defaults to "Dumped").

    NOTE(review): aiofiles and ``asyncio.StreamWriter`` expose slightly
    different write/close APIs (awaitable vs. plain calls), hence the
    isinstance branches below.
    """
    in_stream = input
    if isinstance(in_stream, asyncio.subprocess.Process):
        in_stream = input.stdout

    megabytes_written = 0
    i = 0

    # Iterate over the stream by chunks
    try:
        while True:
            chunk = await in_stream.read(chunk_size)
            if not chunk:
                log.info("End of stream")
                break
            if verbose:
                log.info(chunk)
            megabytes_written += len(chunk) / 1_000_000
            if isinstance(out_stream, AsyncBufferedIOBase):
                await out_stream.write(chunk)
                await out_stream.flush()
            elif out_stream is not None:
                out_stream.write(chunk)
                await out_stream.drain()
            i += 1
            # Refresh the progress line only every 100 chunks to limit
            # stderr churn; "\r" keeps it on a single line.
            if i == 100:
                i = 0
                _print_progress(megabytes_written, end="\r", prefix=prefix)
    except asyncio.CancelledError:
        pass
    finally:
        # Always print the final total, even on cancellation.
        _print_progress(megabytes_written, prefix=prefix)

    if isinstance(out_stream, AsyncBufferedIOBase):
        # BUG FIX: aiofiles file objects have a coroutine close(); without
        # the await the coroutine was never run and the file never closed.
        await out_stream.close()
    elif out_stream is not None:
        out_stream.close()
        await out_stream.wait_closed()


def _print_progress(megabytes: float, **kwargs):
    """Print a progress line like ``Dumped 12.3 MB`` to stderr.

    The ``prefix`` keyword overrides the "Dumped" label; any remaining
    keyword arguments (e.g. a carriage-return ``end``) are forwarded to
    ``print``.
    """
    label = kwargs.pop("prefix", None)
    if label is None:
        label = "Dumped"
    kwargs["file"] = sys.stderr
    print(f"{label} {megabytes:.1f} MB", **kwargs)


async def print_stdout(stream: asyncio.StreamReader):
    """Echo each line of *stream* to the log and, dimmed, to the console."""
    while True:
        raw_line = await stream.readline()
        if not raw_line:
            # EOF — readline returns b"" once the stream is exhausted
            break
        log.info(raw_line)
        console.print(raw_line.decode("utf-8"), style="dim")


class DecodingStreamReader(asyncio.StreamReader):
    """A StreamReader that transparently decompresses gzip data (if compressed).

    The first chunk read is sniffed for the gzip magic number; when present,
    all data is run through a zlib decompressor (handling back-to-back
    concatenated gzip members), otherwise data passes through unchanged.
    """

    # https://ejosh.co/de/2022/08/stream-a-massive-gzipped-json-file-in-python/

    def __init__(self, stream, encoding="utf-8", errors="strict"):
        super().__init__()
        self.stream = stream
        # None until the first chunk is read, then True/False
        self._is_gzipped = None
        # wbits = MAX_WBITS | 16 selects gzip (not raw zlib) framing
        self.d = zlib.decompressobj(zlib.MAX_WBITS | 16)
        # NOTE(review): encoding/errors are accepted for API compatibility
        # but currently unused — reads return bytes, not str. Stored so
        # subclasses/callers can inspect them.
        self.encoding = encoding
        self.errors = errors

    def decompress(self, input: bytes) -> bytes:
        """Decompress *input*, handling concatenated gzip members."""
        data = self.d.decompress(input)
        # A gzip stream may contain several members back-to-back; each time
        # one ends, restart the decompressor on the leftover bytes.
        while self.d.unused_data != b"":
            buf = self.d.unused_data
            self.d = zlib.decompressobj(zlib.MAX_WBITS | 16)
            # BUG FIX: accumulate every member's output — previously the
            # output of intermediate members was overwritten and lost when
            # three or more members arrived in one chunk.
            data += self.d.decompress(buf)
        return data

    def transform_data(self, data):
        """Sniff for gzip on the first chunk, then decompress or pass through."""
        if self._is_gzipped is None:
            self._is_gzipped = data[:2] == b"\x1f\x8b"
            log.info("is_gzipped: %s", self._is_gzipped)
        if self._is_gzipped:
            # Decompress the data
            data = self.decompress(data)
        return data

    async def read(self, n=-1):
        data = await self.stream.read(n)
        return self.transform_data(data)

    async def readline(self):
        res = b""
        while res == b"":
            # Read next line
            line = await self.stream.readline()
            if not line:
                break
            res += self.transform_data(line)
        return res
2 changes: 1 addition & 1 deletion cli/macrostrat/cli/_dev/transfer_tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from .utils import print_stream_progress, print_stdout
from .stream_utils import print_stream_progress, print_stdout
from sqlalchemy.engine import Engine
from .dump_database import _pg_dump
from .restore_database import _pg_restore
Expand Down
70 changes: 42 additions & 28 deletions cli/macrostrat/cli/_dev/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
from urllib.parse import quote
import sys

from aiofiles.threadpool.binary import AsyncBufferedIOBase
from macrostrat.utils import get_logger
from rich.console import Console
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy_utils import create_database, database_exists
from sqlalchemy_utils import create_database, database_exists, drop_database
from macrostrat.core.exc import MacrostratError


console = Console()
Expand All @@ -20,53 +19,64 @@ def _docker_local_run_args(postgres_container: str = "postgres:15"):
"docker",
"run",
"-i",
"--attach",
"stdin",
"--attach",
"stdout",
"--attach",
"stderr",
"--log-driver",
"none",
"--rm",
"--network",
"host",
postgres_container,
]


def _create_database_if_not_exists(_url: URL, create=False):
def _create_database_if_not_exists(
_url: URL, *, create=False, allow_exists=True, overwrite=False
):
database = _url.database
if overwrite:
create = True
db_exists = database_exists(_url)
if db_exists:
console.print(f"Database [bold cyan]{database}[/] already exists")
msg = f"Database [bold underline]{database}[/] already exists"
if overwrite:
console.print(f"{msg}, overwriting")
drop_database(_url)
db_exists = False
elif not allow_exists:
raise MacrostratError(msg, details="Use `--overwrite` to overwrite")
else:
console.print(msg)

if create and not db_exists:
console.print(f"Creating database [bold cyan]{database}[/]")
create_database(_url)

if not db_exists and not create:
raise ValueError(
raise MacrostratError(
f"Database [bold cyan]{database}[/] does not exist. "
"Use `--create` to create it."
)


def _create_command(
engine: Engine,
*command,
args=[],
prefix=None | list[str],
container="postgres:16",
container=None | str,
):
command_prefix = prefix or _docker_local_run_args(container)
_cmd = [*command_prefix, *command, str(engine.url), *args]

log.info(" ".join(_cmd))

# Replace asterisks with the real password (if any). This is kind of backwards
# but it works.
if "***" in str(engine.url) and engine.url.password is not None:
_cmd = [
*command_prefix,
*command,
raw_database_url(engine.url),
*args,
]
"""Create a command for operating on a database"""
_args = []
if container is not None:
_args = _docker_local_run_args(container)

return _cmd
for arg in command:
if isinstance(arg, Engine):
arg = arg.url
if isinstance(arg, URL):
arg = raw_database_url(arg)
_args.append(arg)
return _args


async def print_stream_progress(
Expand Down Expand Up @@ -109,4 +119,8 @@ async def print_stdout(stream: asyncio.StreamReader):


def raw_database_url(url: URL):
return str(url).replace("***", quote(url.password, safe=""))
"""Replace the password asterisks with the actual password, in order to pass to other commands."""
_url = str(url)
if "***" not in _url or url.password is None:
return _url
return _url.replace("***", quote(url.password, safe=""))
12 changes: 12 additions & 0 deletions cli/macrostrat/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ def update_tileserver(db):

app.subsystems.add(MacrostratAPISubsystem(app))

# Mariadb CLI
if mariadb_url := getattr(settings, "mysql_database", None):
from .database.mariadb import app as mariadb_app

main.add_typer(
mariadb_app,
name="mariadb",
rich_help_panel="Subsystems",
short_help="Manage the MariaDB database",
)


app.finish_loading_subsystems()


Expand Down
Loading

0 comments on commit 240d0f6

Please sign in to comment.