diff --git a/typedspark/_core/validate_schema.py b/typedspark/_core/validate_schema.py
index 47c0cef..5de41e6 100644
--- a/typedspark/_core/validate_schema.py
+++ b/typedspark/_core/validate_schema.py
@@ -1,8 +1,10 @@
 """Module containing functions that are related to validating schema's at runtime."""
-from typing import Dict, Set
+from typing import Dict
 
 from pyspark.sql.types import ArrayType, DataType, MapType, StructField, StructType
 
+from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
+
 
 def validate_schema(
     structtype_expected: StructType, structtype_observed: StructType, schema_name: str
@@ -11,7 +13,7 @@ def validate_schema(
     expected = unpack_schema(structtype_expected)
     observed = unpack_schema(structtype_observed)
 
-    check_names(set(expected.keys()), set(observed.keys()), schema_name)
+    check_names(expected, observed, schema_name)
     check_dtypes(expected, observed, schema_name)
 
 
@@ -31,15 +33,26 @@ def unpack_schema(schema: StructType) -> Dict[str, StructField]:
     return res
 
 
-def check_names(names_expected: Set[str], names_observed: Set[str], schema_name: str) -> None:
+def check_names(
+    expected: Dict[str, StructField], observed: Dict[str, StructField], schema_name: str
+) -> None:
     """Checks whether the observed and expected list of column names overlap.
 
     Is order insensitive.
     """
+    names_observed = set(observed.keys())
+    names_expected = set(expected.keys())
+
     diff = names_observed - names_expected
     if diff:
+        diff_schema = create_schema_from_structtype(
+            StructType([observed[colname] for colname in diff]), schema_name
+        )
         raise TypeError(
-            f"Data contains the following columns not present in schema {schema_name}: {diff}"
+            f"Data contains the following columns not present in schema {schema_name}: {diff}.\n\n"
+            "If you believe these columns should be part of the schema, consider adding the "
+            "following lines to it.\n\n"
+            f"{diff_schema.get_schema_definition_as_string(generate_imports=False)}"
         )
 
     diff = names_expected - names_observed
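A minimal sketch of the error path this file now produces, for anyone who wants to try it locally. The schema name "Person" and the extra "age" column are invented for illustration, and the exact lines appended to the message depend on get_schema_definition_as_string().

from pyspark.sql.types import LongType, StringType, StructField, StructType

from typedspark._core.validate_schema import validate_schema

# The observed data carries an extra "age" column that the expected schema lacks.
expected = StructType([StructField("name", StringType())])
observed = StructType(
    [StructField("name", StringType()), StructField("age", LongType())]
)

# Raises TypeError. Previously the message only listed {'age'}; with this change it
# also appends ready-to-paste schema lines (roughly ``age: Column[LongType]``)
# generated from the observed StructFields.
validate_schema(expected, observed, "Person")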
diff --git a/typedspark/_utils/create_dataset_from_structtype.py b/typedspark/_utils/create_dataset_from_structtype.py
new file mode 100644
index 0000000..15102f6
--- /dev/null
+++ b/typedspark/_utils/create_dataset_from_structtype.py
@@ -0,0 +1,70 @@
+"""Utility functions for creating a ``Schema`` from a ``StructType``"""
+from typing import Dict, Literal, Optional, Type
+
+from pyspark.sql.types import ArrayType as SparkArrayType
+from pyspark.sql.types import DataType
+from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
+from pyspark.sql.types import DecimalType as SparkDecimalType
+from pyspark.sql.types import MapType as SparkMapType
+from pyspark.sql.types import StructType as SparkStructType
+
+from typedspark._core.column import Column
+from typedspark._core.datatypes import (
+    ArrayType,
+    DayTimeIntervalType,
+    DecimalType,
+    MapType,
+    StructType,
+)
+from typedspark._schema.schema import MetaSchema, Schema
+from typedspark._utils.camelcase import to_camel_case
+
+
+def create_schema_from_structtype(
+    structtype: SparkStructType, schema_name: Optional[str] = None
+) -> Type[Schema]:
+    """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``"""
+    type_annotations = {}
+    attributes: Dict[str, None] = {}
+    for column in structtype:
+        name = column.name
+        data_type = _extract_data_type(column.dataType, name)
+        type_annotations[name] = Column[data_type]  # type: ignore
+        attributes[name] = None
+
+    if not schema_name:
+        schema_name = "DynamicallyLoadedSchema"
+
+    schema = MetaSchema(schema_name, tuple([Schema]), attributes)
+    schema.__annotations__ = type_annotations
+
+    return schema  # type: ignore
+
+
+def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
+    """Given an instance of a ``DataType``, it extracts the corresponding ``DataType``
+    class, potentially including annotations (e.g. ``ArrayType[StringType]``)."""
+    if isinstance(dtype, SparkArrayType):
+        element_type = _extract_data_type(dtype.elementType, name)
+        return ArrayType[element_type]  # type: ignore
+
+    if isinstance(dtype, SparkMapType):
+        key_type = _extract_data_type(dtype.keyType, name)
+        value_type = _extract_data_type(dtype.valueType, name)
+        return MapType[key_type, value_type]  # type: ignore
+
+    if isinstance(dtype, SparkStructType):
+        subschema = create_schema_from_structtype(dtype, to_camel_case(name))
+        return StructType[subschema]  # type: ignore
+
+    if isinstance(dtype, SparkDayTimeIntervalType):
+        start_field = dtype.startField
+        end_field = dtype.endField
+        return DayTimeIntervalType[Literal[start_field], Literal[end_field]]  # type: ignore
+
+    if isinstance(dtype, SparkDecimalType):
+        precision = dtype.precision
+        scale = dtype.scale
+        return DecimalType[Literal[precision], Literal[scale]]  # type: ignore
+
+    return type(dtype)
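A short sketch of the relocated helper used on its own; the column names below are made up for illustration.

from pyspark.sql.types import LongType, StringType, StructField, StructType

from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype

structtype = StructType(
    [StructField("name", StringType()), StructField("age", LongType())]
)

# Builds a Schema subclass dynamically, with one Column annotation per field.
Person = create_schema_from_structtype(structtype, "Person")

# The generated definition can be printed, e.g. to copy it into source code.
print(Person.get_schema_definition_as_string(generate_imports=True))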
diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py
index 6746134..12cf599 100644
--- a/typedspark/_utils/load_table.py
+++ b/typedspark/_utils/load_table.py
@@ -1,27 +1,13 @@
 """Functions for loading `DataSet` and `Schema` in notebooks."""
 import re
-from typing import Dict, Literal, Optional, Tuple, Type
+from typing import Dict, Optional, Tuple, Type
 
 from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql.types import ArrayType as SparkArrayType
-from pyspark.sql.types import DataType
-from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
-from pyspark.sql.types import DecimalType as SparkDecimalType
-from pyspark.sql.types import MapType as SparkMapType
-from pyspark.sql.types import StructType as SparkStructType
-
-from typedspark._core.column import Column
+
 from typedspark._core.dataset import DataSet
-from typedspark._core.datatypes import (
-    ArrayType,
-    DayTimeIntervalType,
-    DecimalType,
-    MapType,
-    StructType,
-)
-from typedspark._schema.schema import MetaSchema, Schema
-from typedspark._utils.camelcase import to_camel_case
+from typedspark._schema.schema import Schema
+from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
 from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
 
 
@@ -64,60 +50,11 @@ def _replace_illegal_characters(column_name: str) -> str:
     return re.sub("[^A-Za-z0-9]", "_", column_name)
 
 
-def _create_schema(structtype: SparkStructType, schema_name: Optional[str] = None) -> Type[Schema]:
-    """Dynamically builds a ``Schema`` based on a ``DataFrame``'s
-    ``StructType``"""
-    type_annotations = {}
-    attributes: Dict[str, None] = {}
-    for column in structtype:
-        name = column.name
-        data_type = _extract_data_type(column.dataType, name)
-        type_annotations[name] = Column[data_type]  # type: ignore
-        attributes[name] = None
-
-    if not schema_name:
-        schema_name = "DynamicallyLoadedSchema"
-
-    schema = MetaSchema(schema_name, tuple([Schema]), attributes)
-    schema.__annotations__ = type_annotations
-
-    return schema  # type: ignore
-
-
-def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
-    """Given an instance of a ``DataType``, it extracts the corresponding
-    ``DataType`` class, potentially including annotations (e.g.
-    ``ArrayType[StringType]``)."""
-    if isinstance(dtype, SparkArrayType):
-        element_type = _extract_data_type(dtype.elementType, name)
-        return ArrayType[element_type]  # type: ignore
-
-    if isinstance(dtype, SparkMapType):
-        key_type = _extract_data_type(dtype.keyType, name)
-        value_type = _extract_data_type(dtype.valueType, name)
-        return MapType[key_type, value_type]  # type: ignore
-
-    if isinstance(dtype, SparkStructType):
-        subschema = _create_schema(dtype, to_camel_case(name))
-        return StructType[subschema]  # type: ignore
-
-    if isinstance(dtype, SparkDayTimeIntervalType):
-        start_field = dtype.startField
-        end_field = dtype.endField
-        return DayTimeIntervalType[Literal[start_field], Literal[end_field]]  # type: ignore
-
-    if isinstance(dtype, SparkDecimalType):
-        precision = dtype.precision
-        scale = dtype.scale
-        return DecimalType[Literal[precision], Literal[scale]]  # type: ignore
-
-    return type(dtype)
-
-
 def create_schema(
     dataframe: DataFrame, schema_name: Optional[str] = None
 ) -> Tuple[DataSet[Schema], Type[Schema]]:
-    """This function inferres a ``Schema`` in a notebook based on a the provided ``DataFrame``.
+    """This function infers a ``Schema`` in a notebook based on the provided
+    ``DataFrame``.
 
     This allows for autocompletion on column names, amongst other things.
 
@@ -127,7 +64,7 @@ def create_schema(
         df, Person = create_schema(df)
     """
     dataframe = _replace_illegal_column_names(dataframe)
-    schema = _create_schema(dataframe.schema, schema_name)
+    schema = create_schema_from_structtype(dataframe.schema, schema_name)
     dataset = DataSet[schema](dataframe)  # type: ignore
     schema = register_schema_to_dataset(dataset, schema)
     return dataset, schema
@@ -136,8 +73,8 @@ def create_schema(
 
 def load_table(
     spark: SparkSession, table_name: str, schema_name: Optional[str] = None
 ) -> Tuple[DataSet[Schema], Type[Schema]]:
-    """This function loads a ``DataSet``, along with its inferred ``Schema``,
-    in a notebook.
+    """This function loads a ``DataSet``, along with its inferred ``Schema``, in a
+    notebook.
 
     This allows for autocompletion on column names, amongst other things.
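For completeness, a sketch of the notebook-facing entry points that now delegate to create_schema_from_structtype. The DataFrame contents and the "pets" table name are invented; load_table() assumes such a table has been registered in the Spark session.

from pyspark.sql import SparkSession

from typedspark._utils.load_table import create_schema, load_table

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Bobby", 10)], ["name", "age"])

# Infer a Schema from an existing DataFrame (as in the create_schema docstring).
df, Person = create_schema(df, "Person")

# Or load a table together with its inferred Schema.
ds, Pets = load_table(spark, "pets", "Pets")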