diff --git a/requirements-dev.txt b/requirements-dev.txt
index 76e988a..7f9b373 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -8,7 +8,7 @@ black[jupyter]==23.10.1
 isort==5.12.0
 docformatter==1.7.5
 mypy==1.6.1
-pyright==1.1.333
+pyright==1.1.334
 autoflake==2.2.1
 # stubs
 pandas-stubs==2.1.1.230928
@@ -20,7 +20,7 @@ pandas==2.1.2
 setuptools==68.2.2
 chispa==0.9.4
 # notebooks
-nbconvert==7.9.2
+nbconvert==7.11.0
 jupyter==1.0.0
 nbformat==5.9.2
 # readthedocs
diff --git a/setup.py b/setup.py
index 9fa9432..7af67d7 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@ def get_long_description():
     description="Column-wise type annotations for pyspark DataFrames",
     keywords="pyspark spark typing type checking annotations",
     long_description=get_long_description(),
-    long_description_content_type="text/x-rst",
+    long_description_content_type="text/markdown",
     packages=find_packages(include=["typedspark", "typedspark.*"]),
     install_requires=get_requirements(),
     python_requires=">=3.9.0",
diff --git a/tests/_utils/test_register_schema_to_dataset.py b/tests/_utils/test_register_schema_to_dataset.py
index 5679d28..91bbf9b 100644
--- a/tests/_utils/test_register_schema_to_dataset.py
+++ b/tests/_utils/test_register_schema_to_dataset.py
@@ -46,6 +46,7 @@ def test_register_schema_to_dataset(spark: SparkSession):
     job = register_schema_to_dataset(df_b, Job)
 
     assert person.get_schema_name() == "Person"
+    assert hash(person.a) != hash(Person.a)
 
     df_a.join(df_b, person.a == job.a)
 
diff --git a/typedspark/_core/validate_schema.py b/typedspark/_core/validate_schema.py
index 47c0cef..5de41e6 100644
--- a/typedspark/_core/validate_schema.py
+++ b/typedspark/_core/validate_schema.py
@@ -1,8 +1,10 @@
 """Module containing functions that are related to validating schema's at runtime."""
-from typing import Dict, Set
+from typing import Dict
 
 from pyspark.sql.types import ArrayType, DataType, MapType, StructField, StructType
 
+from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
+
 
 def validate_schema(
     structtype_expected: StructType, structtype_observed: StructType, schema_name: str
@@ -11,7 +13,7 @@
     expected = unpack_schema(structtype_expected)
     observed = unpack_schema(structtype_observed)
 
-    check_names(set(expected.keys()), set(observed.keys()), schema_name)
+    check_names(expected, observed, schema_name)
     check_dtypes(expected, observed, schema_name)
 
 
@@ -31,15 +33,26 @@ def unpack_schema(schema: StructType) -> Dict[str, StructField]:
     return res
 
 
-def check_names(names_expected: Set[str], names_observed: Set[str], schema_name: str) -> None:
+def check_names(
+    expected: Dict[str, StructField], observed: Dict[str, StructField], schema_name: str
+) -> None:
     """Checks whether the observed and expected list of column names overlap.
 
     Is order insensitive.
     """
+    names_observed = set(observed.keys())
+    names_expected = set(expected.keys())
+
     diff = names_observed - names_expected
     if diff:
+        diff_schema = create_schema_from_structtype(
+            StructType([observed[colname] for colname in diff]), schema_name
+        )
         raise TypeError(
-            f"Data contains the following columns not present in schema {schema_name}: {diff}"
+            f"Data contains the following columns not present in schema {schema_name}: {diff}.\n\n"
+            "If you believe these columns should be part of the schema, consider adding the "
+            "following lines to it.\n\n"
+            f"{diff_schema.get_schema_definition_as_string(generate_imports=False)}"
         )
 
     diff = names_expected - names_observed
diff --git a/typedspark/_utils/create_dataset_from_structtype.py b/typedspark/_utils/create_dataset_from_structtype.py
new file mode 100644
index 0000000..15102f6
--- /dev/null
+++ b/typedspark/_utils/create_dataset_from_structtype.py
@@ -0,0 +1,70 @@
+"""Utility functions for creating a ``Schema`` from a ``StructType``"""
+from typing import Dict, Literal, Optional, Type
+
+from pyspark.sql.types import ArrayType as SparkArrayType
+from pyspark.sql.types import DataType
+from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
+from pyspark.sql.types import DecimalType as SparkDecimalType
+from pyspark.sql.types import MapType as SparkMapType
+from pyspark.sql.types import StructType as SparkStructType
+
+from typedspark._core.column import Column
+from typedspark._core.datatypes import (
+    ArrayType,
+    DayTimeIntervalType,
+    DecimalType,
+    MapType,
+    StructType,
+)
+from typedspark._schema.schema import MetaSchema, Schema
+from typedspark._utils.camelcase import to_camel_case
+
+
+def create_schema_from_structtype(
+    structtype: SparkStructType, schema_name: Optional[str] = None
+) -> Type[Schema]:
+    """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``"""
+    type_annotations = {}
+    attributes: Dict[str, None] = {}
+    for column in structtype:
+        name = column.name
+        data_type = _extract_data_type(column.dataType, name)
+        type_annotations[name] = Column[data_type]  # type: ignore
+        attributes[name] = None
+
+    if not schema_name:
+        schema_name = "DynamicallyLoadedSchema"
+
+    schema = MetaSchema(schema_name, tuple([Schema]), attributes)
+    schema.__annotations__ = type_annotations
+
+    return schema  # type: ignore
+
+
+def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
+    """Given an instance of a ``DataType``, it extracts the corresponding ``DataType``
+    class, potentially including annotations (e.g. ``ArrayType[StringType]``)."""
+    if isinstance(dtype, SparkArrayType):
+        element_type = _extract_data_type(dtype.elementType, name)
+        return ArrayType[element_type]  # type: ignore
+
+    if isinstance(dtype, SparkMapType):
+        key_type = _extract_data_type(dtype.keyType, name)
+        value_type = _extract_data_type(dtype.valueType, name)
+        return MapType[key_type, value_type]  # type: ignore
+
+    if isinstance(dtype, SparkStructType):
+        subschema = create_schema_from_structtype(dtype, to_camel_case(name))
+        return StructType[subschema]  # type: ignore
+
+    if isinstance(dtype, SparkDayTimeIntervalType):
+        start_field = dtype.startField
+        end_field = dtype.endField
+        return DayTimeIntervalType[Literal[start_field], Literal[end_field]]  # type: ignore
+
+    if isinstance(dtype, SparkDecimalType):
+        precision = dtype.precision
+        scale = dtype.scale
+        return DecimalType[Literal[precision], Literal[scale]]  # type: ignore
+
+    return type(dtype)
diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py
index 12c7937..12cf599 100644
--- a/typedspark/_utils/load_table.py
+++ b/typedspark/_utils/load_table.py
@@ -1,27 +1,13 @@
 """Functions for loading `DataSet` and `Schema` in notebooks."""
 
 import re
-from typing import Dict, Literal, Optional, Tuple, Type
+from typing import Dict, Optional, Tuple, Type
 
 from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql.types import ArrayType as SparkArrayType
-from pyspark.sql.types import DataType
-from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
-from pyspark.sql.types import DecimalType as SparkDecimalType
-from pyspark.sql.types import MapType as SparkMapType
-from pyspark.sql.types import StructType as SparkStructType
-
-from typedspark._core.column import Column
+
 from typedspark._core.dataset import DataSet
-from typedspark._core.datatypes import (
-    ArrayType,
-    DayTimeIntervalType,
-    DecimalType,
-    MapType,
-    StructType,
-)
-from typedspark._schema.schema import MetaSchema, Schema
-from typedspark._utils.camelcase import to_camel_case
+from typedspark._schema.schema import Schema
+from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
 from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
 
 
@@ -64,54 +50,6 @@ def _replace_illegal_characters(column_name: str) -> str:
     return re.sub("[^A-Za-z0-9]", "_", column_name)
 
 
-def _create_schema(structtype: SparkStructType, schema_name: Optional[str] = None) -> Type[Schema]:
-    """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``"""
-    type_annotations = {}
-    attributes: Dict[str, None] = {}
-    for column in structtype:
-        name = column.name
-        data_type = _extract_data_type(column.dataType, name)
-        type_annotations[name] = Column[data_type]  # type: ignore
-        attributes[name] = None
-
-    if not schema_name:
-        schema_name = "DynamicallyLoadedSchema"
-
-    schema = MetaSchema(schema_name, tuple([Schema]), attributes)
-    schema.__annotations__ = type_annotations
-
-    return schema  # type: ignore
-
-
-def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
-    """Given an instance of a ``DataType``, it extracts the corresponding ``DataType``
-    class, potentially including annotations (e.g. ``ArrayType[StringType]``)."""
-    if isinstance(dtype, SparkArrayType):
-        element_type = _extract_data_type(dtype.elementType, name)
-        return ArrayType[element_type]  # type: ignore
-
-    if isinstance(dtype, SparkMapType):
-        key_type = _extract_data_type(dtype.keyType, name)
-        value_type = _extract_data_type(dtype.valueType, name)
-        return MapType[key_type, value_type]  # type: ignore
-
-    if isinstance(dtype, SparkStructType):
-        subschema = _create_schema(dtype, to_camel_case(name))
-        return StructType[subschema]  # type: ignore
-
-    if isinstance(dtype, SparkDayTimeIntervalType):
-        start_field = dtype.startField
-        end_field = dtype.endField
-        return DayTimeIntervalType[Literal[start_field], Literal[end_field]]  # type: ignore
-
-    if isinstance(dtype, SparkDecimalType):
-        precision = dtype.precision
-        scale = dtype.scale
-        return DecimalType[Literal[precision], Literal[scale]]  # type: ignore
-
-    return type(dtype)
-
-
 def create_schema(
     dataframe: DataFrame, schema_name: Optional[str] = None
 ) -> Tuple[DataSet[Schema], Type[Schema]]:
@@ -126,7 +64,7 @@
         df, Person = create_schema(df)
     """
     dataframe = _replace_illegal_column_names(dataframe)
-    schema = _create_schema(dataframe.schema, schema_name)
+    schema = create_schema_from_structtype(dataframe.schema, schema_name)
     dataset = DataSet[schema](dataframe)  # type: ignore
     schema = register_schema_to_dataset(dataset, schema)
     return dataset, schema
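
Usage sketch (not part of the patch): the snippet below illustrates the richer error that check_names now raises when a DataFrame carries columns the schema does not declare. It assumes the public typedspark API (Schema, Column, DataSet) and that schema validation runs when a DataSet is constructed with a parameterized schema; the suggested-schema text comes from get_schema_definition_as_string(generate_imports=False), as wired up above.

    from pyspark.sql import SparkSession
    from pyspark.sql.types import LongType, StringType

    from typedspark import Column, DataSet, Schema


    class Person(Schema):
        id: Column[LongType]
        name: Column[StringType]


    spark = SparkSession.builder.getOrCreate()

    # "age" is not declared on Person, so check_names() raises a TypeError.
    df = spark.createDataFrame([(1, "Alice", 30)], ["id", "name", "age"])

    try:
        DataSet[Person](df)
    except TypeError as exc:
        # With this change, the message also includes a ready-to-paste schema
        # snippet for the unexpected columns, generated via
        # create_schema_from_structtype() and
        # get_schema_definition_as_string(generate_imports=False).
        print(exc)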
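
A second sketch shows the extracted helper on its own, since create_schema_from_structtype is now shared by load_table.py and validate_schema.py. The module path and signature are taken from this patch; the printed definition is approximate.

    from pyspark.sql.types import LongType, StringType, StructField, StructType

    from typedspark._utils.create_dataset_from_structtype import (
        create_schema_from_structtype,
    )

    structtype = StructType(
        [
            StructField("id", LongType()),
            StructField("name", StringType()),
        ]
    )

    PersonSchema = create_schema_from_structtype(structtype, schema_name="Person")

    # Prints something along the lines of:
    #
    #     class Person(Schema):
    #         id: Column[LongType]
    #         name: Column[StringType]
    print(PersonSchema.get_schema_definition_as_string(generate_imports=False))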