
Commit

update
nanne-aben committed Aug 14, 2023
1 parent 83ebda6 commit 69c14f0
Showing 6 changed files with 83 additions and 99 deletions.
118 changes: 50 additions & 68 deletions test.ipynb
@@ -11,105 +11,87 @@
"text": [
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"23/07/03 18:53:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+\n",
"| age|\n",
"+----+\n",
"|null|\n",
"|null|\n",
"|null|\n",
"+----+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
"23/08/14 20:09:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
"from typing import Protocol, TypeVar\n",
"from pyspark.sql import SparkSession\n",
"from pyspark.sql.types import LongType, StringType\n",
"from typedspark import Schema, DataSet, Column, PartialDataSet, transform_to_schema, create_empty_dataset\n",
"from typedspark import (\n",
" Schema,\n",
" DataSet,\n",
" Column,\n",
" DataSetExtends,\n",
" DataSetImplements,\n",
" transform_to_schema,\n",
" create_empty_dataset,\n",
")\n",
"\n",
"\n",
"class Person(Schema):\n",
" name: Column[StringType]\n",
" age: Column[LongType]\n",
"\n",
"\n",
"class Job(Schema):\n",
" role: Column[StringType]\n",
"\n",
"\n",
"class Age(Schema, Protocol):\n",
" type: Column[StringType]\n",
" age: Column[LongType]\n",
"\n",
"def get_age(df: PartialDataSet[Age]) -> DataSet[Age]:\n",
" return transform_to_schema(df, Age)\n",
"\n",
"spark = SparkSession.builder.getOrCreate()\n",
"\n",
"df = create_empty_dataset(spark, Person)\n",
"get_age(df).show()"
"person = create_empty_dataset(spark, Person)\n",
"job = create_empty_dataset(spark, Job)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+----+\n",
"|name| age|\n",
"+----+----+\n",
"|null|null|\n",
"|null|null|\n",
"|null|null|\n",
"+----+----+\n",
"\n"
]
}
],
"outputs": [],
"source": [
"T = TypeVar(\"T\", bound=Schema)\n",
"def get_age(df: DataSetExtends[Age]) -> DataSet[Age]:\n",
" return transform_to_schema(df, Age)\n",
"\n",
"def birthday(df: PartialDataSet[Age], schema: T) -> DataSet[T]:\n",
" return transform_to_schema(\n",
" df, \n",
" schema, # type: ignore\n",
" {Age.age: Age.age + 1}\n",
" )\n",
"\n",
"res: DataSet[Person] = birthday(df, Person)\n",
"res.show()"
"get_age(person)\n",
"try:\n",
" get_age(job) # linting error: Job is not a subtype of Age\n",
"except:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Base class for protocol classes.\\n\\nProtocol classes are defined as::\\n\\n class Proto(Protocol):\\n def meth(self) -> int:\\n ...\\n\\nSuch classes are primarily used with static type checkers that recognize\\nstructural subtyping (static duck-typing), for example::\\n\\n class C:\\n def meth(self) -> int:\\n return 0\\n\\n def func(x: Proto) -> int:\\n return x.meth()\\n\\n func(C()) # Passes static type check\\n\\nSee PEP 544 for details. Protocol classes decorated with\\[email protected]_checkable act as simple-minded runtime protocols that check\\nonly the presence of given attributes, ignoring their type signatures.\\nProtocol classes can be generic, they are defined as::\\n\\n class GenProto(Protocol[T]):\\n def meth(self) -> T:\\n ...'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"Age.get_docstring()"
"T = TypeVar(\"T\", bound=Schema)\n",
"\n",
"\n",
"def birthday(\n",
" df: DataSetImplements[Age, T],\n",
") -> DataSet[T]:\n",
" return transform_to_schema(\n",
" df,\n",
" df.typedspark_schema, # type: ignore\n",
" {\n",
" Age.age: Age.age + 1,\n",
" },\n",
" )\n",
"\n",
"\n",
"res = birthday(person)\n",
"try:\n",
" res = birthday(job) # linting error: Job is not a subtype of Age\n",
"except:\n",
" pass"
]
},
{
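Because the notebook diff above is truncated and interleaves the removed `PartialDataSet` cells with their replacements, the resulting API is easiest to read consolidated. The sketch below reconstructs the new notebook cells; it is illustrative rather than part of the commit, and it assumes the `Age` protocol ends up requiring only the `age` column (the diff does not show its final definition).

```python
# Consolidated sketch of the new notebook cells; not part of the commit.
# Assumption: the Age protocol only requires the `age` column.
from typing import Protocol, TypeVar

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType
from typedspark import (
    Column,
    DataSet,
    DataSetExtends,
    DataSetImplements,
    Schema,
    create_empty_dataset,
    transform_to_schema,
)


class Person(Schema):
    name: Column[StringType]
    age: Column[LongType]


class Job(Schema):
    role: Column[StringType]


class Age(Schema, Protocol):
    age: Column[LongType]


T = TypeVar("T", bound=Schema)


def get_age(df: DataSetExtends[Age]) -> DataSet[Age]:
    # Accepts any DataSet whose schema structurally provides Age's columns;
    # the result is typed against the protocol itself.
    return transform_to_schema(df, Age)


def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
    # Also requires Age's columns, but binds the concrete schema to T, so the
    # return type is preserved (DataSet[Person] in, DataSet[Person] out).
    return transform_to_schema(
        df,
        df.typedspark_schema,  # type: ignore
        {Age.age: Age.age + 1},
    )


spark = SparkSession.builder.getOrCreate()
person = create_empty_dataset(spark, Person)
job = create_empty_dataset(spark, Job)

get_age(person)   # accepted
birthday(person)  # accepted, inferred as DataSet[Person]
try:
    get_age(job)  # linting error: Job is not a subtype of Age
except Exception:
    pass
```

In short, `DataSetExtends[Age]` reads as "any `DataSet` whose schema extends the `Age` protocol", while `DataSetImplements[Age, T]` additionally binds the concrete schema to `T`, so the transformed result keeps its precise type. The class definitions behind this live in `typedspark/_core/dataset.py` further down.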
4 changes: 2 additions & 2 deletions tests/_schema/test_schema.py
@@ -106,7 +106,7 @@ def test_get_snake_case():


def test_get_docstring():
assert A.get_docstring() is None
assert A.get_docstring() == ""
assert PascalCase.get_docstring() == "Schema docstring."


@@ -125,7 +125,7 @@ def test_get_structtype():
def test_get_dlt_kwargs():
assert A.get_dlt_kwargs() == DltKwargs(
name="a",
comment=None,
comment="",
schema=StructType(
[StructField("a", LongType(), True), StructField("b", StringType(), True)]
),
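What these assertions pin down is that an undocumented schema now reports an empty string instead of `None`, and that the Delta Live Tables representation follows suit. A minimal sketch of the new behaviour, with `A` mirroring the columns asserted in `test_get_structtype` and a hypothetical column added to `PascalCase` (its real fields are not shown in this diff):

```python
# Sketch of the new get_docstring() behaviour; PascalCase's column is hypothetical.
from pyspark.sql.types import LongType, StringType
from typedspark import Column, Schema


class A(Schema):
    a: Column[LongType]
    b: Column[StringType]


class PascalCase(Schema):
    """Schema docstring."""

    some_column: Column[LongType]  # hypothetical, for illustration only


assert A.get_docstring() == ""  # previously None
assert PascalCase.get_docstring() == "Schema docstring."
# get_dlt_kwargs() follows the same convention: comment is now "" rather than None.
```

The same convention shows up in `_create_docstring` in `typedspark/_schema/get_schema_definition.py` below, which now compares against the empty string.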
5 changes: 3 additions & 2 deletions typedspark/__init__.py
@@ -2,7 +2,7 @@

from typedspark._core.column import Column
from typedspark._core.column_meta import ColumnMeta
from typedspark._core.dataset import DataSet, PartialDataSet
from typedspark._core.dataset import DataSet, DataSetExtends, DataSetImplements
from typedspark._core.datatypes import (
ArrayType,
DayTimeIntervalType,
@@ -36,7 +36,8 @@
"IntervalType",
"MapType",
"MetaSchema",
"PartialDataSet",
"DataSetExtends",
"DataSetImplements",
"Schema",
"StructType",
"create_empty_dataset",
33 changes: 19 additions & 14 deletions typedspark/_core/dataset.py
@@ -21,18 +21,23 @@
from typedspark._core.validate_schema import validate_schema
from typedspark._schema.schema import Schema

T = TypeVar("T", bound=Schema)
_Schema = TypeVar("_Schema", bound=Schema)
_ReturnType = TypeVar("_ReturnType", bound=DataFrame) # pylint: disable=C0103
P = ParamSpec("P")

V = TypeVar("V", bound=Schema, covariant=True)
_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
_Implementation = TypeVar("_Implementation", bound=Schema, covariant=True)


class PartialDataSet(DataFrame, Generic[V]):
pass
class DataSetImplements(DataFrame, Generic[_Protocol, _Implementation]):
"""TODO."""


class DataSet(PartialDataSet, Generic[T]):
class DataSetExtends(DataSetImplements[_Protocol, _Protocol], Generic[_Protocol]):
""" "TODO."""


class DataSet(DataSetExtends[_Schema]):
"""``DataSet`` subclasses pyspark ``DataFrame`` and hence has all the same
functionality, with the added possibility of defining a schema.
@@ -47,7 +52,7 @@ def foo(df: DataSet[Person]) -> DataSet[Person]:
return df
"""

def __new__(cls, dataframe: DataFrame) -> "DataSet[T]":
def __new__(cls, dataframe: DataFrame) -> "DataSet[_Schema]":
"""``__new__()`` instantiates the object (prior to ``__init__()``).
Here, we simply take the provided ``df`` and cast it to a
@@ -93,18 +98,18 @@ def _add_schema_metadata(self) -> None:
self.schema[field.name].metadata = field.metadata

@property
def typedspark_schema(self) -> Type[T]:
def typedspark_schema(self) -> Type[_Schema]:
"""Returns the ``Schema`` of the ``DataSet``."""
return self._schema_annotations # type: ignore

"""The following functions are equivalent to their parents in ``DataFrame``, but since they
don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
such that the docstring from the parent will appear."""

def distinct(self) -> "DataSet[T]": # pylint: disable=C0116
def distinct(self) -> "DataSet[_Schema]": # pylint: disable=C0116
return DataSet[self._schema_annotations](super().distinct()) # type: ignore

def filter(self, condition) -> "DataSet[T]": # pylint: disable=C0116
def filter(self, condition) -> "DataSet[_Schema]": # pylint: disable=C0116
return DataSet[self._schema_annotations](super().filter(condition)) # type: ignore

@overload
@@ -126,7 +131,7 @@ def join(
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: Literal["semi"] = ...,
) -> "DataSet[T]":
) -> "DataSet[_Schema]":
... # pragma: no cover

@overload
@@ -150,13 +155,13 @@ def join( # pylint: disable=C0116
) -> DataFrame:
return super().join(other, on, how) # type: ignore

def orderBy(self, *args, **kwargs) -> "DataSet[T]": # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103
def orderBy(self, *args, **kwargs) -> "DataSet[_Schema]": # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103
return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs)) # type: ignore

@overload
def transform(
self,
func: Callable[Concatenate["DataSet[T]", P], _ReturnType],
func: Callable[Concatenate["DataSet[_Schema]", P], _ReturnType],
*args: P.args,
**kwargs: P.kwargs,
) -> _ReturnType:
@@ -174,9 +179,9 @@ def transform( # pylint: disable=C0116
@overload
def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: "DataSet[T]",
other: "DataSet[_Schema]",
allowMissingColumns: Literal[False] = ..., # noqa: N803
) -> "DataSet[T]":
) -> "DataSet[_Schema]":
... # pragma: no cover

@overload
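Read together, the three classes form a small hierarchy: `DataSet[S]` is a `DataSetExtends[S]`, which is a `DataSetImplements[S, S]`, and the covariant `_Protocol` parameter is what lets a concrete dataset be passed where only a protocol's columns are required. A rough, self-contained sketch of the intended subtyping (the schemas are stand-ins, not taken from this file):

```python
# Illustrative only: how a type checker is meant to read the hierarchy above.
from typing import Protocol

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType
from typedspark import (
    Column,
    DataSet,
    DataSetExtends,
    DataSetImplements,
    Schema,
    create_empty_dataset,
)


class Age(Schema, Protocol):
    age: Column[LongType]


class Person(Schema):
    name: Column[StringType]
    age: Column[LongType]


spark = SparkSession.builder.getOrCreate()
person: DataSet[Person] = create_empty_dataset(spark, Person)

# DataSet(DataSetExtends[_Schema]) and DataSetExtends(DataSetImplements[_Protocol, _Protocol]):
as_extends: DataSetExtends[Person] = person
as_implements: DataSetImplements[Person, Person] = person

# _Protocol is covariant and Person structurally satisfies the Age protocol,
# so the same dataset is also accepted where only Age's columns are required:
as_age: DataSetExtends[Age] = person
```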
9 changes: 4 additions & 5 deletions typedspark/_schema/get_schema_definition.py
@@ -1,5 +1,4 @@
"""Module to output a string with the ``Schema`` definition of a given
``DataFrame``."""
"""Module to output a string with the ``Schema`` definition of a given ``DataFrame``."""
from __future__ import annotations

import re
@@ -57,7 +56,7 @@ def _build_schema_definition_string(

def _create_docstring(schema: Type[Schema]) -> str:
"""Create the docstring for a given ``Schema``."""
if schema.get_docstring() is not None:
if schema.get_docstring() != "":
docstring = f' """{schema.get_docstring()}"""\n\n'
else:
docstring = ' """Add documentation here."""\n\n'
@@ -140,8 +139,8 @@ def _replace_literal(


def _add_subschemas(schema: Type[Schema], add_subschemas: bool, include_documentation: bool) -> str:
"""Identifies whether any ``Column`` are of the ``StructType`` type and
generates their schema recursively."""
"""Identifies whether any ``Column`` are of the ``StructType`` type and generates
their schema recursively."""
lines = ""
for val in get_type_hints(schema).values():
args = get_args(val)
13 changes: 5 additions & 8 deletions typedspark/_schema/schema.py
@@ -12,7 +12,6 @@
_ProtocolMeta,
get_args,
get_type_hints,
runtime_checkable,
)

from pyspark.sql import DataFrame
@@ -116,8 +115,8 @@ def all_column_names(cls) -> List[str]:
return list(get_type_hints(cls).keys())

def all_column_names_except_for(cls, except_for: List[str]) -> List[str]:
"""Returns all column names for a given schema except for the columns
specified in the ``except_for`` parameter."""
"""Returns all column names for a given schema except for the columns specified
in the ``except_for`` parameter."""
return list(name for name in get_type_hints(cls).keys() if name not in except_for)

def get_snake_case(cls) -> str:
Expand Down Expand Up @@ -177,8 +176,7 @@ def get_structtype(cls) -> StructType:
)

def get_dlt_kwargs(cls, name: Optional[str] = None) -> DltKwargs:
"""Creates a representation of the ``Schema`` to be used by Delta Live
Tables.
"""Creates a representation of the ``Schema`` to be used by Delta Live Tables.
.. code-block:: python
@@ -198,10 +196,9 @@ def get_schema_name(cls):


class Schema(Protocol, metaclass=MetaSchema):
# pylint: disable=missing-class-docstring
# pylint: disable=empty-docstring
# Since docstrings are inherited, and since we use docstrings to
# annotate tables (see MetaSchema.get_dlt_kwargs()), we have chosen
# to add an empty docstring to the Schema class (otherwise the Schema
# docstring would be added to any schema without a docstring).
__doc__ = None
pass
""""""

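The removed notebook cell earlier in this commit shows exactly the leak this comment describes: `Age.get_docstring()` used to return the docstring of `typing.Protocol`, inherited through the base classes. Giving `Schema` an explicit empty docstring stops that fallback, so undocumented user schemas report `""` instead. A small, typedspark-independent illustration of the mechanism (using `inspect.getdoc`, which resolves docstrings through the MRO):

```python
# Standalone illustration of the docstring fallback the empty docstring prevents;
# nothing here depends on typedspark.
import inspect


class DocumentedBase:
    """Base class documentation."""


class UndocumentedChild(DocumentedBase):
    pass


class ShadowedChild(DocumentedBase):
    """"""  # empty docstring, as Schema now defines


print(inspect.getdoc(UndocumentedChild))  # 'Base class documentation.' (inherited)
print(inspect.getdoc(ShadowedChild))      # '' -- the empty docstring shadows the base's
```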