Introduce DataSetImplements #146

Merged · 23 commits · Nov 8, 2023
Changes from all commits
1 change: 1 addition & 0 deletions docs/source/advanced.rst
@@ -6,3 +6,4 @@ Advanced Topics
subclassing_schemas
column_ambiguity
advanced_linting_support
dataset_implements
193 changes: 193 additions & 0 deletions docs/source/dataset_implements.ipynb
@@ -0,0 +1,193 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d3c03896",
"metadata": {},
"source": [
"# Transformations for all schemas with a given column using DataSetImplements\n",
"\n",
"Let's illustrate this with an example! First, we'll define some data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "91c48423",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
"spark.sparkContext.setLogLevel(\"ERROR\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "99c453be",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql.types import LongType, StringType\n",
"from typedspark import (\n",
" Schema,\n",
" Column,\n",
" create_empty_dataset,\n",
")\n",
"\n",
"\n",
"class Person(Schema):\n",
" name: Column[StringType]\n",
" age: Column[LongType]\n",
" job: Column[StringType]\n",
"\n",
"\n",
"class Pet(Schema):\n",
" name: Column[StringType]\n",
" age: Column[LongType]\n",
" type: Column[StringType]\n",
"\n",
"\n",
"class Fruit(Schema):\n",
" type: Column[StringType]\n",
"\n",
"\n",
"person = create_empty_dataset(spark, Person)\n",
"pet = create_empty_dataset(spark, Pet)\n",
"fruit = create_empty_dataset(spark, Fruit)"
]
},
{
"cell_type": "markdown",
"id": "ca634c83",
"metadata": {},
"source": [
"Now, suppose we want to define a function `birthday()` that works on all schemas that contain the column `age`. With `DataSet`, we'd have to specifically indicate which schemas contain the `age` column. We could do this with for example:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c948c8d3",
"metadata": {},
"outputs": [],
"source": [
"from typing import TypeVar, Union\n",
"\n",
"from typedspark import DataSet, transform_to_schema\n",
"\n",
"T = TypeVar(\"T\", bound=Union[Person, Pet])\n",
"\n",
"\n",
"def birthday(df: DataSet[T]) -> DataSet[T]:\n",
" return transform_to_schema(\n",
" df,\n",
" df.typedspark_schema,\n",
" {Person.age: Person.age + 1},\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "9784804d",
"metadata": {},
"source": [
"This can get tedious if the list of schemas with the column `age` changes, for example because new schemas are added, or because the `age` column is removed from a schema! It's also not great that we're using `Person.age` here to define the `age` column...\n",
"\n",
"Fortunately, we can do better! Consider the following example:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b2436106",
"metadata": {},
"outputs": [],
"source": [
"from typing import Protocol\n",
"\n",
"from typedspark import DataSetImplements\n",
"\n",
"\n",
"class Age(Schema, Protocol):\n",
" age: Column[LongType]\n",
"\n",
"\n",
"T = TypeVar(\"T\", bound=Schema)\n",
"\n",
"\n",
"def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:\n",
" return transform_to_schema(\n",
" df,\n",
" df.typedspark_schema,\n",
" {Age.age: Age.age + 1},\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "2088742c",
"metadata": {},
"source": [
"Here, we define `Age` to be both a `Schema` and a `Protocol` ([PEP-0544](https://peps.python.org/pep-0544/)). \n",
"\n",
"We then define `birthday()` to:\n",
"1. Take as an input `DataSetImplements[Age, T]`: a `DataSet` that implements the protocol `Age` as `T`. \n",
"2. Return a `DataSet[T]`: a `DataSet` of the same type as the one that was provided.\n",
"\n",
"Let's see this in action!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5658210f",
"metadata": {},
"outputs": [],
"source": [
"# returns a DataSet[Person]\n",
"happy_person = birthday(person)\n",
"\n",
"# returns a DataSet[Pet]\n",
"happy_pet = birthday(pet)\n",
"\n",
"try:\n",
" # Raises a linting error:\n",
" # Argument of type \"DataSet[Fruit]\" cannot be assigned to\n",
" # parameter \"df\" of type \"DataSetImplements[Age, T@birthday]\"\n",
" birthday(fruit)\n",
"except Exception as e:\n",
" pass"
]
},
{
"cell_type": "markdown",
"id": "5bfb99ed",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "typedspark",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
6 changes: 6 additions & 0 deletions tests/_core/test_dataset.py
@@ -4,6 +4,7 @@
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema
from typedspark._core.dataset import DataSetImplements
from typedspark._utils.create_dataset import create_empty_dataset


@@ -88,3 +89,8 @@ def test_inherrited_functions_with_other_dataset(spark: SparkSession):
def test_schema_property_of_dataset(spark: SparkSession):
df = create_empty_dataset(spark, A)
assert df.typedspark_schema == A


def test_initialize_dataset_implements(spark: SparkSession):
with pytest.raises(NotImplementedError):
DataSetImplements()
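
The new test asserts that `DataSetImplements` cannot be instantiated directly; it exists purely for type annotations. Below is a minimal sketch of the kind of guard this implies. The `__new__` body is an illustrative assumption, not necessarily the library's actual implementation:

class DataSetImplements:
    """Typing-only construct: annotate parameters as
    DataSetImplements[SomeProtocol, T]; at runtime the caller still passes a plain DataSet."""

    def __new__(cls, *args, **kwargs):
        # Direct instantiation is not supported; the class is only meaningful as an annotation.
        raise NotImplementedError()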
4 changes: 2 additions & 2 deletions tests/_schema/test_schema.py
@@ -106,7 +106,7 @@ def test_get_snake_case():


def test_get_docstring():
assert A.get_docstring() is None
assert A.get_docstring() == ""
assert PascalCase.get_docstring() == "Schema docstring."


@@ -125,7 +125,7 @@ def test_get_structtype():
def test_get_dlt_kwargs():
assert A.get_dlt_kwargs() == DltKwargs(
name="a",
comment=None,
comment="",
schema=StructType(
[StructField("a", LongType(), True), StructField("b", StringType(), True)]
),
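
The two assertion changes above reflect that a schema without a docstring now yields an empty string rather than `None`, both from `get_docstring()` and as the DLT `comment`. For context, the `A` schema these tests use presumably looks roughly like this (a hypothetical reconstruction from the StructType in the assertion, not copied from the test module):

from pyspark.sql.types import LongType, StringType
from typedspark import Column, Schema


class A(Schema):  # no class docstring, so get_docstring() now returns ""
    a: Column[LongType]
    b: Column[StringType]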
3 changes: 2 additions & 1 deletion typedspark/__init__.py
@@ -2,7 +2,7 @@

from typedspark._core.column import Column
from typedspark._core.column_meta import ColumnMeta
from typedspark._core.dataset import DataSet
from typedspark._core.dataset import DataSet, DataSetImplements
from typedspark._core.datatypes import (
ArrayType,
DayTimeIntervalType,
@@ -39,6 +39,7 @@
"IntervalType",
"MapType",
"MetaSchema",
"DataSetImplements",
"Schema",
"StructType",
"create_empty_dataset",
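
With `DataSetImplements` re-exported from the package root, downstream code only needs top-level imports to use the pattern from the notebook. A brief sketch mirroring the notebook's `Age` protocol and `birthday()` function:

from typing import Protocol, TypeVar

from pyspark.sql.types import LongType
from typedspark import Column, DataSet, DataSetImplements, Schema, transform_to_schema


class Age(Schema, Protocol):
    # Protocol schema: any schema containing an `age` column of LongType satisfies it.
    age: Column[LongType]


T = TypeVar("T", bound=Schema)


def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
    # Returns a DataSet with the same schema, with `age` incremented by one.
    return transform_to_schema(df, df.typedspark_schema, {Age.age: Age.age + 1})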