Skip to content

Commit

Permalink
fixed formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargav Dodla committed Sep 29, 2023
1 parent 8b2c5ca commit 593b3bc
Showing 1 changed file with 107 additions and 41 deletions.
148 changes: 107 additions & 41 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def __init__(
assert registry_config is not None, "SqlRegistry needs a valid registry_config"
# pool_recycle will recycle connections after the given number of seconds has passed
# This is to avoid automatic disconnections when no activity is detected on connection
self.engine: Engine = create_engine(registry_config.path, echo=False, pool_recycle=3600)
self.engine: Engine = create_engine(
registry_config.path, echo=False, pool_recycle=3600
)
metadata.create_all(self.engine)
self.project = project
if project is not None:
Expand Down Expand Up @@ -257,19 +259,26 @@ def refresh(self, project: Optional[str] = None):
if project_metadata:
usage.set_current_project_uuid(project_metadata.project_uuid)
else:
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
proto_registry_utils.init_project_metadata(
self.cached_registry_proto, project
)
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = datetime.utcnow()

def _refresh_cached_registry_if_necessary(self):
with self._refresh_lock:
expired = (
self.cached_registry_proto is None or self.cached_registry_proto_created is None
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds() > 0 # 0 ttl means infinity
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
datetime.utcnow()
> (self.cached_registry_proto_created + self.cached_registry_proto_ttl)
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)

Expand All @@ -279,7 +288,10 @@ def _refresh_cached_registry_if_necessary(self):

def _check_if_registry_refreshed(self):
CACHE_REFRESH_THRESHOLD_SECONDS = 300
if (self.cached_registry_proto is None or self.cached_registry_proto_created is None) or (
if (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds() > 0 # 0 ttl means infinity
and (
datetime.utcnow()
Expand All @@ -294,7 +306,9 @@ def _check_if_registry_refreshed(self):
f"Cache is stale: {seconds_since_last_refresh} seconds since last refresh"
)

def get_stream_feature_view(self, name: str, project: str, allow_cache: bool = False):
def get_stream_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.get_stream_feature_view(
Expand Down Expand Up @@ -343,7 +357,9 @@ def apply_entity(self, entity: Entity, project: str, commit: bool = True):
def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.get_entity(self.cached_registry_proto, name, project)
return proto_registry_utils.get_entity(
self.cached_registry_proto, name, project
)
return self._get_object(
table=entities,
name=name,
Expand All @@ -355,10 +371,14 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti
not_found_exception=EntityNotFoundException,
)

def get_feature_view(self, name: str, project: str, allow_cache: bool = False) -> FeatureView:
def get_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureView:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.get_feature_view(self.cached_registry_proto, name, project)
return proto_registry_utils.get_feature_view(
self.cached_registry_proto, name, project
)
return self._get_object(
table=feature_views,
name=name,
Expand Down Expand Up @@ -389,7 +409,9 @@ def get_on_demand_feature_view(
not_found_exception=FeatureViewNotFoundException,
)

def get_request_feature_view(self, name: str, project: str, allow_cache: bool = False):
def get_request_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.get_request_feature_view(
Expand Down Expand Up @@ -482,11 +504,17 @@ def list_validation_references(
def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_entities(self.cached_registry_proto, project)
return self._list_objects(entities, project, EntityProto, Entity, "entity_proto")
return proto_registry_utils.list_entities(
self.cached_registry_proto, project
)
return self._list_objects(
entities, project, EntityProto, Entity, "entity_proto"
)

def delete_entity(self, name: str, project: str, commit: bool = True):
return self._delete_object(entities, name, project, "entity_name", EntityNotFoundException)
return self._delete_object(
entities, name, project, "entity_name", EntityNotFoundException
)

def delete_feature_view(self, name: str, project: str, commit: bool = True):
deleted_count = 0
Expand All @@ -496,7 +524,9 @@ def delete_feature_view(self, name: str, project: str, commit: bool = True):
on_demand_feature_views,
stream_feature_views,
}:
deleted_count += self._delete_object(table, name, project, "feature_view_name", None)
deleted_count += self._delete_object(
table, name, project, "feature_view_name", None
)
if deleted_count == 0:
raise FeatureViewNotFoundException(name, project)

Expand All @@ -509,10 +539,14 @@ def delete_feature_service(self, name: str, project: str, commit: bool = True):
FeatureServiceNotFoundException,
)

def get_data_source(self, name: str, project: str, allow_cache: bool = False) -> DataSource:
def get_data_source(
self, name: str, project: str, allow_cache: bool = False
) -> DataSource:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.get_data_source(self.cached_registry_proto, name, project)
return proto_registry_utils.get_data_source(
self.cached_registry_proto, name, project
)
return self._get_object(
table=data_sources,
name=name,
Expand All @@ -524,20 +558,28 @@ def get_data_source(self, name: str, project: str, allow_cache: bool = False) ->
not_found_exception=DataSourceObjectNotFoundException,
)

def list_data_sources(self, project: str, allow_cache: bool = False) -> List[DataSource]:
def list_data_sources(
self, project: str, allow_cache: bool = False
) -> List[DataSource]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_data_sources(self.cached_registry_proto, project)
return proto_registry_utils.list_data_sources(
self.cached_registry_proto, project
)
return self._list_objects(
data_sources, project, DataSourceProto, DataSource, "data_source_proto"
)

def apply_data_source(self, data_source: DataSource, project: str, commit: bool = True):
def apply_data_source(
self, data_source: DataSource, project: str, commit: bool = True
):
return self._apply_object(
data_sources, project, "data_source_name", data_source, "data_source_proto"
)

def apply_feature_view(self, feature_view: BaseFeatureView, project: str, commit: bool = True):
def apply_feature_view(
self, feature_view: BaseFeatureView, project: str, commit: bool = True
):
fv_table = self._infer_fv_table(feature_view)

return self._apply_object(
Expand Down Expand Up @@ -570,7 +612,9 @@ def list_feature_services(
) -> List[FeatureService]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_feature_services(self.cached_registry_proto, project)
return proto_registry_utils.list_feature_services(
self.cached_registry_proto, project
)
return self._list_objects(
feature_services,
project,
Expand All @@ -579,18 +623,26 @@ def list_feature_services(
"feature_service_proto",
)

def list_feature_views(self, project: str, allow_cache: bool = False) -> List[FeatureView]:
def list_feature_views(
self, project: str, allow_cache: bool = False
) -> List[FeatureView]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_feature_views(self.cached_registry_proto, project)
return proto_registry_utils.list_feature_views(
self.cached_registry_proto, project
)
return self._list_objects(
feature_views, project, FeatureViewProto, FeatureView, "feature_view_proto"
)

def list_saved_datasets(self, project: str, allow_cache: bool = False) -> List[SavedDataset]:
def list_saved_datasets(
self, project: str, allow_cache: bool = False
) -> List[SavedDataset]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_saved_datasets(self.cached_registry_proto, project)
return proto_registry_utils.list_saved_datasets(
self.cached_registry_proto, project
)
return self._list_objects(
saved_datasets,
project,
Expand Down Expand Up @@ -636,7 +688,9 @@ def list_project_metadata(
) -> List[ProjectMetadata]:
if allow_cache:
self._check_if_registry_refreshed()
return proto_registry_utils.list_project_metadata(self.cached_registry_proto, project)
return proto_registry_utils.list_project_metadata(
self.cached_registry_proto, project
)
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
Expand Down Expand Up @@ -706,7 +760,9 @@ def apply_materialization(
FeatureViewNotFoundException,
)
fv.materialization_intervals.append((start_date, end_date))
self._apply_object(table, project, "feature_view_name", fv, "feature_view_proto")
self._apply_object(
table, project, "feature_view_name", fv, "feature_view_proto"
)

def delete_validation_reference(self, name: str, project: str, commit: bool = True):
self._delete_object(
Expand Down Expand Up @@ -804,7 +860,9 @@ def _infer_fv_classes(self, feature_view):
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return python_class, proto_class

def get_user_metadata(self, project: str, feature_view: BaseFeatureView) -> Optional[bytes]:
def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
table = self._infer_fv_table(feature_view)

name = feature_view.name
Expand Down Expand Up @@ -907,7 +965,9 @@ def _apply_object(
else:
obj_proto = obj.to_proto()

if hasattr(obj_proto, "meta") and hasattr(obj_proto.meta, "created_timestamp"):
if hasattr(obj_proto, "meta") and hasattr(
obj_proto.meta, "created_timestamp"
):
obj_proto.meta.created_timestamp.FromDatetime(update_datetime)

values = {
Expand Down Expand Up @@ -1004,15 +1064,18 @@ def _list_objects(
rows = conn.execute(stmt).all()
if rows:
return [
python_class.from_proto(proto_class.FromString(row[proto_field_name]))
python_class.from_proto(
proto_class.FromString(row[proto_field_name])
)
for row in rows
]
return []

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
feast_metadata.c.project_id == project,
)
row = conn.execute(stmt).first()
Expand Down Expand Up @@ -1045,7 +1108,8 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
def _get_last_updated_metadata(self, project: str):
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
feast_metadata.c.project_id == project,
)
row = conn.execute(stmt).first()
Expand Down Expand Up @@ -1092,23 +1156,25 @@ def get_all_project_metadata(self) -> List[ProjectMetadataModel]:
project_name=project_id
)

project_metadata_model: ProjectMetadataModel = project_metadata_model_dict[
project_id
]
project_metadata_model: ProjectMetadataModel = (
project_metadata_model_dict[project_id]
)
if metadata_key == FeastMetadataKeys.PROJECT_UUID.value:
project_metadata_model.project_uuid = metadata_value

if metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value:
project_metadata_model.last_updated_timestamp = datetime.utcfromtimestamp(
int(metadata_value)
project_metadata_model.last_updated_timestamp = (
datetime.utcfromtimestamp(int(metadata_value))
)
return list(project_metadata_model_dict.values())

def get_project_metadata(self, project: str) -> ProjectMetadataModel:
"""
        Returns the given project's metadata. There is no supporting function in the SQL Registry, so this is implemented here rather than using _get_last_updated_metadata and list_project_metadata.
"""
project_metadata_model: ProjectMetadataModel = ProjectMetadataModel(project_name=project)
project_metadata_model: ProjectMetadataModel = ProjectMetadataModel(
project_name=project
)
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
Expand All @@ -1123,7 +1189,7 @@ def get_project_metadata(self, project: str) -> ProjectMetadataModel:
project_metadata_model.project_uuid = metadata_value

if metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value:
project_metadata_model.last_updated_timestamp = datetime.utcfromtimestamp(
int(metadata_value)
project_metadata_model.last_updated_timestamp = (
datetime.utcfromtimestamp(int(metadata_value))
)
return project_metadata_model

0 comments on commit 593b3bc

Please sign in to comment.