Skip to content

Commit

Permalink
cleanup and switching to peewee orm
Browse files Browse the repository at this point in the history
  • Loading branch information
havok2063 committed Oct 20, 2023
1 parent 0630ab3 commit 09b128d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 31 deletions.
12 changes: 6 additions & 6 deletions python/valis/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __getattr__(self, name):

def connect_db(db, orm: str = 'peewee'):
""" Connect to the peewee sdss5db database """

from valis.main import settings
profset = db.set_profile(settings.db_server)
if settings.db_remote and not profset:
Expand All @@ -52,8 +53,7 @@ def connect_db(db, orm: str = 'peewee'):
db.connect_from_parameters(dbname='sdss5db', host=host, port=port,
user=user, password=passwd)

print(db)
print(db.connection_params)
# raise error if we cannot connect
if not db.connected:
raise HTTPException(status_code=503, detail=f'Could not connect to database via sdssdb {orm}.')

Expand All @@ -62,6 +62,8 @@ def connect_db(db, orm: str = 'peewee'):

def get_pw_db(db_state=Depends(reset_db_state)):
""" Dependency to connect a database with peewee """

# connect to the db, yield None since we don't need the db in peewee
db = connect_db(pdb, orm='peewee')
try:
yield None
Expand All @@ -72,11 +74,9 @@ def get_pw_db(db_state=Depends(reset_db_state)):

def get_sqla_db():
""" Dependency to connect to a database with sqlalchemy """
db = connect_db(sdb, orm='sqla')

if not db.connected:
raise HTTPException(status_code=503, detail='Could not connect to database via sdssdb sqla.')

# connect to the db, yield the db Session object for sql queries
db = connect_db(sdb, orm='sqla')
db = db.Session()
try:
yield db
Expand Down
19 changes: 16 additions & 3 deletions python/valis/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
#

# all resuable Pydantic models of the ORMs go here

import peewee
from typing import Any, Optional
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -32,8 +34,19 @@ class Config:
getter_dict = PeeweeGetterDict


# class SDSSidStackedBaseA(OrmBase):
# """ Pydantic model for the SQLA vizdb.SDSSidStacked ORM """

# sdss_id: int = Field(..., description='the SDSS identifier')
# ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
# dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid')
# catalogid21: Optional[int] = Field(description='the version 21 catalog id')
# catalogid25: Optional[int] = Field(description='the version 25 catalog id')
# catalogid31: Optional[int] = Field(description='the version 31 catalog id')


class SDSSidStackedBase(PeeweeBase):
""" Pydantic model for the SQLA vizdb.SDSSidStacked ORM """
""" Pydantic model for the Peewee vizdb.SDSSidStacked ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
Expand All @@ -44,7 +57,7 @@ class SDSSidStackedBase(PeeweeBase):


class SDSSidFlatBase(PeeweeBase):
""" Pydantic model for the SQLA vizdb.SDSSidFlat ORM """
""" Pydantic model for the Peewee vizdb.SDSSidFlat ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
Expand All @@ -57,7 +70,7 @@ class SDSSidFlatBase(PeeweeBase):


class SDSSidPipesBase(PeeweeBase):
""" Pydantic model for the SQLA vizdb.SDSSidToPipes ORM """
""" Pydantic model for the Peewee vizdb.SDSSidToPipes ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
in_boss: bool = Field(..., description='Flag if the sdss_id is in the BHM reductions')
Expand Down
4 changes: 4 additions & 0 deletions python/valis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
"name": "mocs",
"description": "Access SDSS surveys MOCs",
},
{
"name": "query",
"description": "Query the SDSS databases",
},
]


Expand Down
31 changes: 9 additions & 22 deletions python/valis/routes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from fastapi import APIRouter, Depends, Query
from fastapi_utils.cbv import cbv

from sdssdb.peewee.sdss5db import vizdb

from valis.routes.base import Base
from valis.db.db import get_pw_db, get_sqla_db
from valis.db.db import get_pw_db
from valis.db.models import SDSSidStackedBase
from valis.db.queries import cone_search


router = APIRouter()
Expand All @@ -20,28 +19,14 @@
class Query(Base):
""" API routes for performing queries against sdss5db """

@router.get("/testa", summary='slqa test')
async def get_testa(self, db=Depends(get_sqla_db)) -> dict:
""" Get a list of available SDSS maskbits schema or flag names """
#if db and db.connected:
from sdssdb.sqlalchemy.mangadb import datadb
return db.query(datadb.Cube).first()
#else:
# return {"message": "Not connected to database"}

# @router.get("/testp", summary='peewee test', response_model=SourceBase,
# dependencies=[Depends(get_pw_db)])
# async def get_testp(self) -> dict:
# """ Get a list of available SDSS maskbits schema or flag names """
# return catalogdb.Gaia_edr3_allwise_best_neighbour.select().first()

# @router.get('/acone', summary='Perform a cone search for SDSS targets with sdss_ids',
# response_model=List[SDSSidStackedBase])
# async def cone_search(self,
# @router.get('/test_sqla', summary='Perform a cone search for SDSS targets with sdss_ids',
# response_model=List[SDSSidStackedBaseA])
# async def test_search(self,
# ra=Query(..., description='right ascension in degrees', example=315.01417),
# dec=Query(..., description='declination in degrees', example=35.299),
# radius=Query(..., description='the search radius in degrees', example=0.01),
# db=Depends(get_sqla_db)):
# """ Example for writing a route with a sqlalchemy ORM """
# from sdssdb.sqlalchemy.sdss5db import vizdb

# return db.query(vizdb.SDSSidStacked).\
Expand All @@ -53,4 +38,6 @@ async def cone_search(self,
ra=Query(..., description='right ascension in degrees', example=315.01417),
dec=Query(..., description='declination in degrees', example=35.299),
radius=Query(..., description='the search radius in degrees', example=0.01)):
return list(vizdb.SDSSidStacked.select().where(vizdb.SDSSidStacked.cone_search(ra, dec, radius, ra_col='ra_sdss_id', dec_col='dec_sdss_id')))
""" Perform a cone search """
return list(cone_search(ra, dec, radius))

0 comments on commit 09b128d

Please sign in to comment.