Skip to content

Commit

Permalink
fix(ingest): glue import type stubs only for testing (#3032)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhu authored Aug 4, 2021
1 parent 8c4a141 commit 3d06116
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 61 deletions.
17 changes: 10 additions & 7 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from functools import reduce
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import boto3
from boto3.session import Session
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient

from datahub.configuration import ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.emitter.mce_builder import DEFAULT_ENV

if TYPE_CHECKING:

from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient


def assume_role(
role_arn: str, aws_region: str, credentials: Optional[dict] = None
Expand Down Expand Up @@ -88,13 +91,13 @@ def get_session(self) -> Session:
else:
return Session(region_name=self.aws_region)

def get_s3_client(self) -> S3Client:
def get_s3_client(self) -> "S3Client":
return self.get_session().client("s3")

def get_glue_client(self) -> GlueClient:
def get_glue_client(self) -> "GlueClient":
return self.get_session().client("glue")

def get_sagemaker_client(self) -> SageMakerClient:
def get_sagemaker_client(self) -> "SageMakerClient":
return self.get_session().client("sagemaker")


Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from dataclasses import dataclass
from typing import Iterable, List

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)
from typing import TYPE_CHECKING, Iterable, List

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
Expand All @@ -27,14 +20,23 @@
MLPrimaryKeyPropertiesClass,
)

if TYPE_CHECKING:

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)


@dataclass
class FeatureGroupProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport

def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:
def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
"""
List all feature groups in SageMaker.
"""
Expand All @@ -50,7 +52,7 @@ def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:

def get_feature_group_details(
self, feature_group_name: str
) -> DescribeFeatureGroupResponseTypeDef:
) -> "DescribeFeatureGroupResponseTypeDef":
"""
Get details of a feature group (including list of component features).
"""
Expand All @@ -74,7 +76,7 @@ def get_feature_group_details(
return feature_group

def get_feature_group_wu(
self, feature_group_details: DescribeFeatureGroupResponseTypeDef
self, feature_group_details: "DescribeFeatureGroupResponseTypeDef"
) -> MetadataWorkUnit:
"""
Generate an MLFeatureTable workunit for a SageMaker feature group.
Expand Down Expand Up @@ -146,8 +148,8 @@ def get_feature_type(self, aws_type: str, feature_name: str) -> str:

def get_feature_wu(
self,
feature_group_details: DescribeFeatureGroupResponseTypeDef,
feature: FeatureDefinitionTypeDef,
feature_group_details: "DescribeFeatureGroupResponseTypeDef",
feature: "FeatureDefinitionTypeDef",
) -> MetadataWorkUnit:
"""
Generate an MLFeature workunit for a SageMaker feature.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Expand All @@ -16,8 +17,6 @@
Union,
)

from mypy_boto3_sagemaker import SageMakerClient

from datahub.emitter import mce_builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import make_s3_urn
Expand Down Expand Up @@ -47,6 +46,9 @@
JobStatusClass,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
Expand Down Expand Up @@ -151,7 +153,7 @@ class JobProcessor:
"""

# boto3 SageMaker client
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, DefaultDict, Dict, List, Set

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set

from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceReport,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)


@dataclass
class LineageInfo:
Expand Down Expand Up @@ -42,13 +43,13 @@ class LineageInfo:

@dataclass
class LineageProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
lineage_info: LineageInfo = field(default_factory=LineageInfo)

def get_all_actions(self) -> List[ActionSummaryTypeDef]:
def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
"""
List all actions in SageMaker.
"""
Expand All @@ -62,7 +63,7 @@ def get_all_actions(self) -> List[ActionSummaryTypeDef]:

return actions

def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:
def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
"""
List all artifacts in SageMaker.
"""
Expand All @@ -76,7 +77,7 @@ def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:

return artifacts

def get_all_contexts(self) -> List[ContextSummaryTypeDef]:
def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
"""
List all contexts in SageMaker.
"""
Expand All @@ -90,7 +91,7 @@ def get_all_contexts(self) -> List[ContextSummaryTypeDef]:

return contexts

def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all incoming edges for a node in the lineage graph.
"""
Expand All @@ -104,7 +105,7 @@ def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:

return edges

def get_outgoing_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all outgoing edges for a node in the lineage graph.
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
from typing import (
TYPE_CHECKING,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)

import datahub.emitter.mce_builder as builder
Expand Down Expand Up @@ -43,6 +42,17 @@
OwnershipTypeClass,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
)

ENDPOINT_STATUS_MAP: Dict[str, str] = {
"OutOfService": DeploymentStatusClass.OUT_OF_SERVICE,
"Creating": DeploymentStatusClass.CREATING,
Expand All @@ -58,7 +68,7 @@

@dataclass
class ModelProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
lineage: LineageInfo
Expand All @@ -81,7 +91,7 @@ class ModelProcessor:

group_arn_to_name: Dict[str, str] = field(default_factory=dict)

def get_all_models(self) -> List[ModelSummaryTypeDef]:
def get_all_models(self) -> List["ModelSummaryTypeDef"]:
"""
List all models in SageMaker.
"""
Expand All @@ -95,15 +105,15 @@ def get_all_models(self) -> List[ModelSummaryTypeDef]:

return models

def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef:
def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef":
"""
Get details of a model.
"""

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model
return self.sagemaker_client.describe_model(ModelName=model_name)

def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:
def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
"""
List all model groups in SageMaker.
"""
Expand All @@ -118,7 +128,7 @@ def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:

def get_group_details(
self, group_name: str
) -> DescribeModelPackageGroupOutputTypeDef:
) -> "DescribeModelPackageGroupOutputTypeDef":
"""
Get details of a model group.
"""
Expand All @@ -128,7 +138,7 @@ def get_group_details(
ModelPackageGroupName=group_name
)

def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:
def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:

endpoints = []

Expand All @@ -140,7 +150,9 @@ def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:

return endpoints

def get_endpoint_details(self, endpoint_name: str) -> DescribeEndpointOutputTypeDef:
def get_endpoint_details(
self, endpoint_name: str
) -> "DescribeEndpointOutputTypeDef":

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint
return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
Expand All @@ -162,7 +174,7 @@ def get_endpoint_status(
return endpoint_status

def get_endpoint_wu(
self, endpoint_details: DescribeEndpointOutputTypeDef
self, endpoint_details: "DescribeEndpointOutputTypeDef"
) -> MetadataWorkUnit:
"""a
Get a workunit for an endpoint.
Expand Down Expand Up @@ -206,7 +218,7 @@ def get_endpoint_wu(

def get_model_endpoints(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
model_image: Optional[str],
model_uri: Optional[str],
Expand Down Expand Up @@ -235,7 +247,7 @@ def get_model_endpoints(
return model_endpoints_sorted

def get_group_wu(
self, group_details: DescribeModelPackageGroupOutputTypeDef
self, group_details: "DescribeModelPackageGroupOutputTypeDef"
) -> MetadataWorkUnit:
"""
Get a workunit for a model group.
Expand Down Expand Up @@ -285,7 +297,7 @@ def get_group_wu(
return MetadataWorkUnit(id=group_name, mce=mce)

def match_model_jobs(
self, model_details: DescribeModelOutputTypeDef
self, model_details: "DescribeModelOutputTypeDef"
) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]:

model_training_jobs: Set[str] = set()
Expand Down Expand Up @@ -380,7 +392,7 @@ def strip_quotes(string: str) -> str:

def get_model_wu(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
) -> MetadataWorkUnit:
"""
Expand Down

0 comments on commit 3d06116

Please sign in to comment.