Skip to content
This repository has been archived by the owner on Apr 15, 2022. It is now read-only.

Commit

Permalink
primary keys in vector (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
Myles Novick authored Apr 5, 2021
1 parent 51fadb8 commit 80519d0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions feature_store/src/rest_api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,31 @@ def validate_feature_vector_keys(join_key_values, feature_sets) -> None:
raise SpliceMachineException(status_code=status.HTTP_400_BAD_REQUEST, code=ExceptionCodes.MISSING_ARGUMENTS,
message=f"The following keys were not provided and must be: {missing_keys}")

def get_feature_vector(db: Session, feats: List[schemas.Feature], join_keys: Dict[str, Union[str, int]], feature_sets: List[schemas.FeatureSet], return_sql: bool) -> Union[Dict[str, Any], str]:
def get_feature_vector(db: Session, feats: List[schemas.Feature], join_keys: Dict[str, Union[str, int]], feature_sets: List[schemas.FeatureSet],
return_pks: bool, return_sql: bool) -> Union[Dict[str, Any], str]:
"""
Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
:param db: SqlAlchemy Session
:param features: List of Features
:param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
:param feature_sets: List of Feature Sets
:param return_pks: Whether to return the Feature Set primary keys in the vector. Default True
:param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
:return: Dict or str (SQL statement)
"""
metadata = MetaData(db.get_bind())

tables = [Table(fset.table_name.lower(), metadata, PrimaryKeyConstraint(*[pk.lower() for pk in fset.primary_keys]), schema=fset.schema_name.lower(), autoload=True).\
alias(f'fset{fset.feature_set_id}') for fset in feature_sets]
columns = [getattr(table.c, f.name.lower()) for f in feats for table in tables if f.name.lower() in table.c]

columns = []
if return_pks:
seen = set()
pks = [seen.add(pk_col.name) or getattr(table.c, pk_col.name) for table in tables for pk_col in table.primary_key if pk_col.name not in seen]
columns.extend(pks)

columns.extend([getattr(table.c, f.name.lower()) for f in feats for table in tables if f.name.lower() in table.c])

# For each Feature Set, for each primary key in the given feature set, get primary key value from the user provided dictionary
filters = [getattr(table.c, pk_col.name)==join_keys[pk_col.name.lower()]
Expand Down
4 changes: 2 additions & 2 deletions feature_store/src/rest_api/routers/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_features_by_name(names: List[str] = Query([], alias="name"), db: Session
description="Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets",
operation_id='get_feature_vector', tags=['Features'])
@managed_transaction
def get_feature_vector(fjk: schemas.FeatureJoinKeys, sql: bool = False, db: Session = Depends(crud.get_db)):
def get_feature_vector(fjk: schemas.FeatureJoinKeys, pks: bool = True, sql: bool = False, db: Session = Depends(crud.get_db)):
"""
Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
"""
Expand All @@ -93,7 +93,7 @@ def get_feature_vector(fjk: schemas.FeatureJoinKeys, sql: bool = False, db: Sess
feature_sets = crud.get_feature_sets(db, [f.feature_set_id for f in feats])
crud.validate_feature_vector_keys(join_keys, feature_sets)

return crud.get_feature_vector(db, feats, join_keys, feature_sets, sql)
return crud.get_feature_vector(db, feats, join_keys, feature_sets, pks, sql)


@SYNC_ROUTER.post('/feature-vector-sql', status_code=status.HTTP_200_OK, response_model=str,
Expand Down

0 comments on commit 80519d0

Please sign in to comment.