diff --git a/docs/api.yml b/docs/api.yml index 08d9a5259..ee625d4ed 100644 --- a/docs/api.yml +++ b/docs/api.yml @@ -96,16 +96,11 @@ sidebar: - "api-reference/expressions/typeof" - "api-reference/expressions/when" # - "api-reference/expressions/datetime" - # - "api-reference/expressions/str/lower" - # - "api-reference/expressions/str/upper" - # - "api-reference/expressions/str/endswith" - # - "api-reference/expressions/str/concat" - # - "api-reference/expressions/str/parse" - # - "api-reference/expressions/str/len" - # - "api-reference/expressions/str/contains" - # - "api-reference/expressions/list.len" - # - "api-reference/expressions/list.hasnull" - # - "api-reference/expressions/list.contains" + # - "api-reference/expressions/from_epoch" + + - slug: "api-reference/expressions/dt" + title: "Datetime Expressions" + pages: # - "api-reference/expressions/dt.since" # - "api-reference/expressions/dt.since_epoch" # - "api-reference/expressions/dt.year" @@ -115,14 +110,13 @@ sidebar: # - "api-reference/expressions/dt.minute" # - "api-reference/expressions/dt.second" # - "api-reference/expressions/dt.strftime" - # - "api-reference/expressions/struct" - # - "api-reference/expressions/from_epoch" - # - "api-reference/expressions/struct.get" - - slug: "api-reference/expressions/str" - title: "String Expressions" + - slug: "api-reference/expressions/list" + title: "List Expressions" pages: - - "api-reference/expressions/str/startswith" + # - "api-reference/expressions/list.len" + # - "api-reference/expressions/list.hasnull" + # - "api-reference/expressions/list.contains" - slug: "api-reference/expressions/num" title: "Num Expressions" @@ -132,6 +126,26 @@ sidebar: - "api-reference/expressions/num/floor" - "api-reference/expressions/num/round" + - slug: "api-reference/expressions/str" + title: "String Expressions" + pages: + - "api-reference/expressions/str/concat" + - "api-reference/expressions/str/contains" + - "api-reference/expressions/str/endswith" + - "api-reference/expressions/str/len" + - "api-reference/expressions/str/lower" + - "api-reference/expressions/str/parse" + - "api-reference/expressions/str/startswith" + - "api-reference/expressions/str/strptime" + - "api-reference/expressions/str/upper" + + - slug: "api-reference/expressions/struct" + title: "Struct Expressions" + pages: + - "api-reference/expressions/struct/get" + # - "api-reference/expressions/struct/init" + + - slug: "api-reference/decorators" title: "Decorators" pages: diff --git a/docs/examples/api-reference/expressions/basic.py b/docs/examples/api-reference/expressions/basic.py index 37e8c3e61..344456251 100644 --- a/docs/examples/api-reference/expressions/basic.py +++ b/docs/examples/api-reference/expressions/basic.py @@ -2,25 +2,6 @@ from typing import Optional import pandas as pd -def test_num_abs(): - # docsnip expr_num_abs - from fennel.expr import col - - # docsnip-highlight next-line - expr = col("x").num.abs() - assert expr.typeof(schema={"x": int}) == int - assert expr.typeof(schema={"x": Optional[int]}) == Optional[int] - assert expr.typeof(schema={"x": float}) == float - assert expr.typeof(schema={"x": Optional[float]}) == Optional[float] - - # can be evaluated with a dataframe - df = pd.DataFrame({"x": pd.Series([1, -2, None], dtype=pd.Int64Dtype())}) - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [1, 2, None] - - with pytest.raises(ValueError): - expr.typeof(schema={"x": str}) - # /docsnip - def test_unary_not(): # docsnip expr_unary_not from fennel.expr import lit @@ -63,7 +44,7 @@ def test_col(): def test_when_then(): # docsnip expr_when_then - from fennel.expr import when, col + from fennel.expr import when, col, InvalidExprException # docsnip-highlight next-line expr = when(col("x")).then(1).otherwise(0) @@ -84,8 +65,8 @@ def test_when_then(): assert expr.eval(df, schema={"x": bool}).tolist() == [1, 0, 1] # not valid if only when is provided - expr = when(col("x")) - with pytest.raises(ValueError): + with pytest.raises(InvalidExprException): + expr = when(col("x")) expr.typeof(schema={"x": bool}) # if otherwise is not provided, it defaults to None @@ -114,7 +95,7 @@ def test_isnull(): import pandas as pd df = pd.DataFrame({"x": pd.Series([1, 2, None], dtype=pd.Int64Dtype())}) - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [True, False, False] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [False, False, True] # /docsnip def test_fillnull(): @@ -153,5 +134,5 @@ def test_lit(): # can be evaluated with a dataframe expr = col("x") + lit(1) df = pd.DataFrame({"x": pd.Series([1, 2, None], dtype=pd.Int64Dtype())}) - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [2, 3, None] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [2, 3, pd.NA] # /docsnip \ No newline at end of file diff --git a/docs/examples/api-reference/expressions/binary.py b/docs/examples/api-reference/expressions/binary.py index 211914990..a8a2fc907 100644 --- a/docs/examples/api-reference/expressions/binary.py +++ b/docs/examples/api-reference/expressions/binary.py @@ -14,14 +14,14 @@ def test_typeof(): df = pd.DataFrame({"x": [1, 2, None]}) expr = lit(1) + col("x") - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [2, 3, None] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [2, 3, pd.NA] expr = lit(1) - col("x") - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [0, -1, None] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [0, -1, pd.NA] expr = lit(1) * col("x") - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [1, 2, None] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [1, 2, pd.NA] expr = lit(1) / col("x") - assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [1, 0.5, None] + assert expr.eval(df, schema={"x": Optional[int]}).tolist() == [1, 0.5, pd.NA] # /docsnip diff --git a/docs/examples/api-reference/expressions/eval.py b/docs/examples/api-reference/expressions/eval.py index ed3ad813c..e9968358a 100644 --- a/docs/examples/api-reference/expressions/eval.py +++ b/docs/examples/api-reference/expressions/eval.py @@ -46,7 +46,7 @@ def test_eval(): # dataframe doesn't have the required column even though schema is provided df = pd.DataFrame({"other": [1, 2, 3]}) - with pytest.raises(KeyError): + with pytest.raises(Exception): expr.eval(df, schema={"amount": int}) # /docsnip \ No newline at end of file diff --git a/docs/examples/api-reference/expressions/num.py b/docs/examples/api-reference/expressions/num.py index 8468d0c19..50d92c59f 100644 --- a/docs/examples/api-reference/expressions/num.py +++ b/docs/examples/api-reference/expressions/num.py @@ -60,4 +60,45 @@ def test_ceil(): with pytest.raises(ValueError): expr.typeof(schema={"x": str}) - # /docsnip \ No newline at end of file + # /docsnip + + +def test_round(): + # docsnip round + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").round() # equivalent to col("x").num.round() + + assert expr.typeof(schema={"x": int}) == int + assert expr.typeof(schema={"x": Optional[int]}) == Optional[int] + assert expr.typeof(schema={"x": float}) == int + assert expr.typeof(schema={"x": Optional[float]}) == Optional[int] + + # can be evaluated with a dataframe + df = pd.DataFrame({"x": pd.Series([1.1, -2.3, None])}) + assert expr.eval(df, schema={"x": Optional[float]}).tolist() == [1, -2, pd.NA] + + # can also explicit specify the number of decimals + # docsnip-highlight next-line + expr = col("x").round(1) + + assert expr.typeof(schema={"x": int}) == float + assert expr.typeof(schema={"x": Optional[int]}) == Optional[float] + assert expr.typeof(schema={"x": float}) == float + assert expr.typeof(schema={"x": Optional[float]}) == Optional[float] + + df = pd.DataFrame({"x": pd.Series([1.12, -2.37, None])}) + assert expr.eval(df, schema={"x": Optional[float]}).tolist() == [1.1, -2.4, pd.NA] + + df = pd.DataFrame({"x": pd.Series([1, -2, None])}) + assert expr.eval(df, schema={"x": Optional[float]}).tolist() == [1.0, -2.0, pd.NA] + + # /docsnip + + # invalid number of decimals + with pytest.raises(Exception): + expr = col("x").round(-1) + + with pytest.raises(Exception): + expr = col("x").round(1.1) \ No newline at end of file diff --git a/docs/examples/api-reference/expressions/str.py b/docs/examples/api-reference/expressions/str.py new file mode 100644 index 000000000..537b3ffba --- /dev/null +++ b/docs/examples/api-reference/expressions/str.py @@ -0,0 +1,317 @@ +import pytest +from typing import Optional, List +import pandas as pd + + +def test_concact(): + # docsnip concat + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.concat(col("y")) + + assert expr.typeof(schema={"x": str, "y": str}) == str + assert expr.typeof(schema={"x": str, "y": Optional[str]}) == Optional[str] + assert expr.typeof(schema={"x": Optional[str], "y": str}) == Optional[str] + assert expr.typeof(schema={"x": Optional[str], "y": Optional[str]}) == Optional[str] + + # can be evaluated with a dataframe + df = pd.DataFrame({ + "x": ["hello", "world", "some", None], + "y": [" world", " hello", None, None], + }) + schema = {"x": Optional[str], "y": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == ["hello world", "world hello", pd.NA, pd.NA] + + # schema of both columns must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": str}) + + with pytest.raises(Exception): + expr.typeof(schema={"x": str, "y": int}) + # /docsnip + + +def test_contains(): + # docsnip contains + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.contains(col("y")) + + assert expr.typeof(schema={"x": str, "y": str}) == bool + assert expr.typeof(schema={"x": str, "y": Optional[str]}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": str}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": Optional[str]}) == Optional[bool] + + # can be evaluated with a dataframe + df = pd.DataFrame({ + "x": ["hello", "world", "some", None], + "y": ["ell", "random", None, None], + }) + schema = {"x": Optional[str], "y": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == [True, False, pd.NA, pd.NA] + + # schema of both columns must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": str}) + + with pytest.raises(Exception): + expr.typeof(schema={"x": str, "y": int}) + # /docsnip + +def test_startswith(): + # docsnip startswith + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.startswith(col("y")) + + assert expr.typeof(schema={"x": str, "y": str}) == bool + assert expr.typeof(schema={"x": str, "y": Optional[str]}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": str}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": Optional[str]}) == Optional[bool] + + # can be evaluated with a dataframe + df = pd.DataFrame({ + "x": ["hello", "world", "some", None], + "y": ["he", "rld", None, None], + }) + schema = {"x": Optional[str], "y": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == [True, False, pd.NA, pd.NA] + + # schema of both columns must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": str}) + + with pytest.raises(Exception): + expr.typeof(schema={"x": str, "y": int}) + # /docsnip + + +def test_endswith(): + # docsnip endswith + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.endswith(col("y")) + + assert expr.typeof(schema={"x": str, "y": str}) == bool + assert expr.typeof(schema={"x": str, "y": Optional[str]}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": str}) == Optional[bool] + assert expr.typeof(schema={"x": Optional[str], "y": Optional[str]}) == Optional[bool] + + # can be evaluated with a dataframe + df = pd.DataFrame({ + "x": ["hello", "world", "some", None], + "y": ["lo", "wor", None, None], + }) + schema = {"x": Optional[str], "y": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == [True, False, pd.NA, pd.NA] + + # schema of both columns must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": str}) + + with pytest.raises(Exception): + expr.typeof(schema={"x": str, "y": int}) + # /docsnip + + +def test_lower(): + # docsnip lower + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.lower() + + assert expr.typeof(schema={"x": str}) == str + assert expr.typeof(schema={"x": Optional[str]}) == Optional[str] + + # can be evaluated with a dataframe + df = pd.DataFrame({"x": ["HeLLo", "World", "some", None]}) + schema = {"x": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == ["hello", "world", "some", pd.NA] + + # schema of column must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": int}) + # /docsnip + + +def test_upper(): + # docsnip upper + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.upper() + + assert expr.typeof(schema={"x": str}) == str + assert expr.typeof(schema={"x": Optional[str]}) == Optional[str] + + # can be evaluated with a dataframe + df = pd.DataFrame({"x": ["HeLLo", "World", "some", None]}) + schema = {"x": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == ["HELLO", "WORLD", "SOME", pd.NA] + + # schema of column must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": int}) + # /docsnip + + +def test_len(): + # docsnip len + from fennel.expr import col + + # docsnip-highlight next-line + expr = col("x").str.len() + + assert expr.typeof(schema={"x": str}) == int + assert expr.typeof(schema={"x": Optional[str]}) == Optional[int] + + # can be evaluated with a dataframe + df = pd.DataFrame({"x": ["hello", "world", "some", None]}) + schema = {"x": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == [5, 5, 4, pd.NA] + + # schema of column must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": int}) + # /docsnip + + +def test_parse_basic(): + # docsnip parse_basic + from fennel.expr import col, lit + + # docsnip-highlight next-line + expr = col("x").str.parse(list[int]) + + assert expr.typeof(schema={"x": str}) == List[int] + assert expr.typeof(schema={"x": Optional[str]}) == Optional[List[int]] + + # can be evaluated with a dataframe + df = pd.DataFrame({"x": ["[1, 2, 3]", "[4, 5]", None]}) + schema = {"x": Optional[str]} + assert expr.eval(df, schema=schema).tolist() == [[1, 2, 3], [4, 5], pd.NA] + + # schema of column must be str + with pytest.raises(ValueError): + expr.typeof(schema={"x": int}) + + # can use this to parse several common types + df = pd.DataFrame({"x": ["1"]}) + schema = {"x": str} + cases = [ + ("1", int, 1), + ("1.1", float, 1.1), + ("true", bool, True), + ("false", bool, False), + ("\"hi\"", str, "hi"), + ] + for case in cases: + expr = lit(case[0]).str.parse(case[1]) + assert expr.eval(df, schema).tolist() == [case[2]] + # /docsnip + + +def test_parse_invalid(): + # docsnip parse_invalid + from fennel.expr import col, lit + + invalids = [ + ("False", bool), # "False" is not valid json, "false" is + ("hi", str), # "hi" is not valid json, "\"hi\"" is + ("[1, 2, 3", List[int]), + ("1.1.1", float), + ] + for invalid in invalids: + expr = lit(invalid[0]).str.parse(invalid[1]) + df = pd.DataFrame({"x": ["1"]}) + schema = {"x": str} + with pytest.raises(Exception): + expr.eval(df, schema) + # /docsnip + + +def test_parse_struct(): + # docsnip parse_struct + from fennel.expr import col, lit + from fennel.dtypes import struct + + @struct + class MyStruct: + x: int + y: Optional[bool] + + cases = [ + ("{\"x\": 1, \"y\": true}", MyStruct(1, True)), + ("{\"x\": 2, \"y\": null}", MyStruct(2, None)), + ("{\"x\": 3}", MyStruct(3, None)), + ] + for case in cases: + expr = lit(case[0]).str.parse(MyStruct) + df = pd.DataFrame({"x": ["1"]}) + schema = {"x": str} + found = expr.eval(df, schema).tolist() + assert len(found) == 1 + assert found[0].x == case[1].x + assert found[0].y == case[1].y + # /docsnip + + # can also parse a list of structs + df = pd.DataFrame({"x": ["[{\"x\": 1, \"y\": true}, {\"x\": 2, \"y\": null}, null]"]}) + schema = {"x": str} + target = List[Optional[MyStruct]] + expr = col("x").str.parse(target) + found = expr.eval(df, schema).tolist() + assert len(found) == 1 + assert len(found[0]) == 3 + assert found[0][0].x == 1 + assert found[0][0].y == True + assert found[0][1].x == 2 + assert found[0][1].y == None + assert found[0][2] == None + # /docsnip + +def test_strptime(): + # docsnip strptime + from fennel.expr import col + from datetime import datetime + + # docsnip-highlight next-line + expr = col("x").str.strptime("%Y-%m-%d") + + assert expr.typeof(schema={"x": str}) == datetime + assert expr.typeof(schema={"x": Optional[str]}) == Optional[datetime] + + # TODO: replace NaT with pd.NA + # TODO: replace pd.Timestamp with datetime + df = pd.DataFrame({"x": ["2021-01-01", "2021-02-01", None]}) + schema = {"x": Optional[str]} + assert expr.eval(df, schema).tolist() == [ + pd.Timestamp(2021, 1, 1, tz="UTC"), + pd.Timestamp(2021, 2, 1, tz="UTC"), + pd.NaT, + ] + + # can also provide a timezone + expr = col("x").str.strptime("%Y-%m-%d", timezone="Asia/Tokyo") + + assert expr.eval(df, schema).tolist() == [ + pd.Timestamp(2021, 1, 1, tz="Asia/Tokyo"), + pd.Timestamp(2021, 2, 1, tz="Asia/Tokyo"), + pd.NaT, + ] + + # error on invalid format - %L is not a valid format + expr = col("x").str.strptime("%Y-%m-%d %L)") + with pytest.raises(Exception): + expr.eval(df, schema) + + # error on invalid timezone + expr = col("x").str.strptime("%Y-%m-%d", timezone="invalid") + with pytest.raises(Exception): + expr.eval(df, schema) + # /docsnip \ No newline at end of file diff --git a/docs/examples/api-reference/expressions/struct_snip.py b/docs/examples/api-reference/expressions/struct_snip.py new file mode 100644 index 000000000..dd96bf5df --- /dev/null +++ b/docs/examples/api-reference/expressions/struct_snip.py @@ -0,0 +1,34 @@ +import pytest +from typing import Optional, List +import pandas as pd + + +def test_get(): + # docsnip get + from fennel.expr import col + from fennel.dtypes import struct + + @struct + class MyStruct: + f1: int + f2: bool + + # docsnip-highlight next-line + expr = col("x").struct.get("f1") + assert expr.typeof(schema={"x": MyStruct}) == int + + # error to get a field that does not exist + expr = col("x").struct.get("z") + with pytest.raises(ValueError): + expr.typeof(schema={"x": MyStruct}) + + # can be evaluated with a dataframe + df = pd.DataFrame({ + "x": [MyStruct(1, True), MyStruct(2, False), None], + }) + schema = {"x": Optional[MyStruct]} + expr = col("x").struct.get("f1") + result = expr.eval(df, schema=schema) + print(result) + assert expr.eval(df, schema=schema).tolist() == [1, 2, pd.NA] + # /docsnip \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/num/abs.md b/docs/pages/api-reference/expressions/num/abs.md index 028627d05..12acfb8a9 100644 --- a/docs/pages/api-reference/expressions/num/abs.md +++ b/docs/pages/api-reference/expressions/num/abs.md @@ -4,7 +4,7 @@ order: 0 status: published --- -### abs +### Abs Function to get the absolute value of a number. diff --git a/docs/pages/api-reference/expressions/num/ceil.md b/docs/pages/api-reference/expressions/num/ceil.md index 540e1ba33..4da1da6aa 100644 --- a/docs/pages/api-reference/expressions/num/ceil.md +++ b/docs/pages/api-reference/expressions/num/ceil.md @@ -4,7 +4,7 @@ order: 0 status: published --- -### ceil +### Ceil Function in `num` namespace to get the ceil of a number. diff --git a/docs/pages/api-reference/expressions/num/floor.md b/docs/pages/api-reference/expressions/num/floor.md index 22c2124d6..639b5bc21 100644 --- a/docs/pages/api-reference/expressions/num/floor.md +++ b/docs/pages/api-reference/expressions/num/floor.md @@ -4,7 +4,7 @@ order: 0 status: published --- -### floor +### Floor Function in `num` namespace to get the floor of a number. diff --git a/docs/pages/api-reference/expressions/num/round.md b/docs/pages/api-reference/expressions/num/round.md index 83cf0bb0a..ec2430dd3 100644 --- a/docs/pages/api-reference/expressions/num/round.md +++ b/docs/pages/api-reference/expressions/num/round.md @@ -4,23 +4,32 @@ order: 0 status: published --- -### round +### Round Function in `num` namespace to round a number. +#### Parameters + +The number of the decimal places to round the input to. + + #### Returns -Returns an expression object denoting the ceil of the input data. The -data type of the resulting expression is `int` if the input was `int` or `float` -or `Optional[int]` when the input is `Optional[int]` or `Optional[float]`. +Returns an expression object denoting the rounded value of the input data. The +data type of the resulting expression is `int` / `Optional[int]` if precision is +set to `0` or `float` / `Optional[int]` for precisions > 0.
+status="success" message="Rounding a number using Fennel expressions">
 
#### Errors Error during `typeof` or `eval` if the input expression is not of type int, float, optional int or optional float. + + + +Precision must be a non-negative integer. \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/str/concat.md b/docs/pages/api-reference/expressions/str/concat.md new file mode 100644 index 000000000..896a67380 --- /dev/null +++ b/docs/pages/api-reference/expressions/str/concat.md @@ -0,0 +1,34 @@ +--- +title: Concat +order: 0 +status: published +--- + +### Concat + +Function to concatenate two strings. + +#### Parameters + +The string to be concatenated with the base string. + + +
+
+ + +#### Returns + +Returns an expression object denoting the result of the `concact` expression. +The resulting expression is of type `str` or `Optional[str]` depending on +either of input/item being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. Similarly, `item` must evaluate to either a string or an +optional of string. + diff --git a/docs/pages/api-reference/expressions/str/contains.md b/docs/pages/api-reference/expressions/str/contains.md new file mode 100644 index 000000000..e70a85ea7 --- /dev/null +++ b/docs/pages/api-reference/expressions/str/contains.md @@ -0,0 +1,34 @@ +--- +title: Contains +order: 0 +status: published +--- + +### Contains + +Function to check if the given string contains another string. + +#### Parameters + +`contains` check if the base string contains `item` or not. + + +
+
+ + +#### Returns + +Returns an expression object denoting the result of the `contains` expression. +The resulting expression is of type `bool` or `Optional[bool]` depending on +either of input/item being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. Similarly, `item` must evaluate to either a string or an +optional of string. + \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/str/endswith.md b/docs/pages/api-reference/expressions/str/endswith.md new file mode 100644 index 000000000..ffaca922c --- /dev/null +++ b/docs/pages/api-reference/expressions/str/endswith.md @@ -0,0 +1,34 @@ +--- +title: Ends With +order: 0 +status: published +--- + +### Ends With + +Function to check if the given string ends with the given another string. + +#### Parameters + +`endswith` checks if the input string ends with the expression `item`. + + +
+
+ + +#### Returns + +Returns an expression object denoting the result of the `endswith` expression. +The resulting expression is of type `bool` or `Optional[bool]` depending on +either of input/item being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. Similarly, `item` must evaluate to either a string or an +optional of string. + diff --git a/docs/pages/api-reference/expressions/str/len.md b/docs/pages/api-reference/expressions/str/len.md new file mode 100644 index 000000000..5b4a747dd --- /dev/null +++ b/docs/pages/api-reference/expressions/str/len.md @@ -0,0 +1,27 @@ +--- +title: Len +order: 0 +status: published +--- + +### Len + +Function to get the length of a string + +
+
+ +#### Returns + +Returns an expression object denoting the result of the `len` function. +The resulting expression is of type `int` or `Optional[int]` depending on +input being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. + diff --git a/docs/pages/api-reference/expressions/str/lower.md b/docs/pages/api-reference/expressions/str/lower.md new file mode 100644 index 000000000..760055271 --- /dev/null +++ b/docs/pages/api-reference/expressions/str/lower.md @@ -0,0 +1,27 @@ +--- +title: Lower +order: 0 +status: published +--- + +### Lower + +Function to convert a string to all lowercase letters. + +
+
+ +#### Returns + +Returns an expression object denoting the result of the `lower` function. +The resulting expression is of type `str` or `Optional[str]` depending on +input being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. + diff --git a/docs/pages/api-reference/expressions/str/parse.md b/docs/pages/api-reference/expressions/str/parse.md new file mode 100644 index 000000000..bf30528b8 --- /dev/null +++ b/docs/pages/api-reference/expressions/str/parse.md @@ -0,0 +1,54 @@ +--- +title: Parse +order: 0 +status: published +--- + +### Parse + +Function to parse an object of the given type out of a string that represents json +encoded data. + +#### Parameters + +The type of the data should be parsed from the json encoded string. + + +
+
+ +
+
+ +
+
+ + +#### Returns + +Returns an expression object denoting the result of the `parse` expression. +The resulting expression is of type `dtype` or `Optional[dtype]` depending on +the base string being nullable. + + +:::info +A type can only be parsed out of valid json representation of that type. For +instance, a `str` can not be parsed out of `"hi"` because the correct json +representation of the string is `"\"hi\""`. +::: + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. + + + + +If the given string can not be parsed into an object of the given type, a runtime +error is raised. + \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/str/startswith.md b/docs/pages/api-reference/expressions/str/startswith.md index 6d47dfc5f..157beadce 100644 --- a/docs/pages/api-reference/expressions/str/startswith.md +++ b/docs/pages/api-reference/expressions/str/startswith.md @@ -1,35 +1,33 @@ --- -title: String Startswith +title: Starts With order: 0 status: published --- -### str.startswith +### Starts With -Function in `str` namespace to check if the given string starts with another -string. +Function to check if the given string starts with another string. #### Parameters - -The name of the column being referenced. In the case of pipelines, this will -typically be the name of the field and in the case of extractors, this will -be the name of the feature. + +`startswith` checks if the input string starts with the expression `item`. -
+
 
#### Returns -Returns an expression object denoting a reference to the column. The type of -the resulting expression is same as that of the referenced column. When evaluated -in the context of a dataframe, the value of the expression is same as the -value of the dataframe column of that name. +Returns an expression object denoting the result of the `startswith` expression. +The resulting expression is of type `bool` or `Optional[bool]` depending on +either of input/item being nullable. #### Errors - -Error during `typeof` or `eval` if the referenced column isn't present. + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. Similarly, `item` must evaluate to either a string or an +optional of string. diff --git a/docs/pages/api-reference/expressions/str/strptime.md b/docs/pages/api-reference/expressions/str/strptime.md new file mode 100644 index 000000000..233373269 --- /dev/null +++ b/docs/pages/api-reference/expressions/str/strptime.md @@ -0,0 +1,48 @@ +--- +title: Strptime +order: 0 +status: published +--- + +### Strptime + +Function to parse a datetime of the given format out of the string. + +#### Parameters + +A valid datetime format string. See +[here](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) for a +full list of all format qualifiers supported by Fennel. + + + +Sometimes format strings don't precisely specify the timezone. In such cases, +a timezone can be provided. In absence of an explicit timezone, all ambiguous +strings are assumed to be in UTC. + +Note that `timezone` is merely a hint to resolve disambiguity - the timezone +info from the format string is preferentially used when available. + + + +
+
+ +#### Returns + +Returns an expression object denoting the result of the `strptime` expression. +The resulting expression is of type `datetime` or `Optional[datetime]` depending on +either of input/item being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. + + + +Compile time error is raised if either of the format string or timezone is invalid. + \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/str/upper.md b/docs/pages/api-reference/expressions/str/upper.md new file mode 100644 index 000000000..316b3d5fb --- /dev/null +++ b/docs/pages/api-reference/expressions/str/upper.md @@ -0,0 +1,27 @@ +--- +title: Upper +order: 0 +status: published +--- + +### Upper + +Function to convert a string to all upper case letters. + +
+
+ +#### Returns + +Returns an expression object denoting the result of the `upper` function. +The resulting expression is of type `str` or `Optional[str]` depending on +input being nullable. + + + +#### Errors + +The `str` namespace must be invoked on an expression that evaluates to string +or optional of string. + diff --git a/docs/pages/api-reference/expressions/struct/get.md b/docs/pages/api-reference/expressions/struct/get.md new file mode 100644 index 000000000..aee67ab2c --- /dev/null +++ b/docs/pages/api-reference/expressions/struct/get.md @@ -0,0 +1,37 @@ +--- +title: Get +order: 0 +status: published +--- + +### Get + +Function to get a given field from a struct. + +#### Parameters + +The name of the field that needs to be obtained from the struct. Note that this +must be a literal string, not an expression. + + +
+
+ +#### Returns + +Returns an expression object denoting the result of the `get` operation. +If the corresponding field in the struct is of type `T`, the resulting expression +is of type `T` or `Optional[T]` depending on the struct itself being nullable. + + + +#### Errors + +The `struct` namespace must be invoked on an expression that evaluates to struct. + + + +Compile error is raised when trying to get a field that doesn't exist on the +struct. + \ No newline at end of file diff --git a/docs/pages/api-reference/expressions/when.md b/docs/pages/api-reference/expressions/when.md index 5403e8cf4..a2896f179 100644 --- a/docs/pages/api-reference/expressions/when.md +++ b/docs/pages/api-reference/expressions/when.md @@ -19,9 +19,11 @@ evaluates to True. `then` must always be called on the result of a `when` expression.
- + The equivalent of `else` branch in the ternary expression - the whole expression -evaluates to this branch when the predicate evaluates to be False. +evaluates to this branch when the predicate evaluates to be False. + +Defaults to `lit(None)` when not provided. diff --git a/fennel/expr/__init__.py b/fennel/expr/__init__.py index d1c29db03..04d0bc03e 100644 --- a/fennel/expr/__init__.py +++ b/fennel/expr/__init__.py @@ -1 +1 @@ -from fennel.expr.expr import col, lit, when, Expr +from fennel.expr.expr import col, lit, when, Expr, InvalidExprException diff --git a/fennel/expr/expr.py b/fennel/expr/expr.py index 6240a3ecb..3537ac3ef 100644 --- a/fennel/expr/expr.py +++ b/fennel/expr/expr.py @@ -12,7 +12,7 @@ from fennel.dtypes.dtypes import FENNEL_STRUCT, FENNEL_STRUCT_SRC_CODE import pandas as pd -from fennel.internal_lib.schema.schema import from_proto, parse_json +from fennel.internal_lib.schema.schema import convert_dtype_to_arrow_type_with_nullable, from_proto, parse_json import pyarrow as pa from fennel_data_lib import eval, type_of @@ -83,7 +83,15 @@ def fillnull(self, value: Any): def abs(self) -> _Number: return _Number(self, Abs()) - def round(self, precision: int) -> _Number: + def round(self, precision: int = 0) -> _Number: + if not isinstance(precision, int): + raise InvalidExprException( + f"precision can only be an int but got {precision} instead" + ) + if precision < 0: + raise InvalidExprException( + f"precision can only be a positive int but got {precision} instead" + ) return _Number(self, Round(precision)) def ceil(self) -> _Number: @@ -107,6 +115,9 @@ def __nonzero__(self): def __bool__(self): raise InvalidExprException("can not convert: '%s' to bool" % self) + def __neg__(self) -> Expr: + return Unary("-", self) + def __add__(self, other: Any) -> Expr: other = make_expr(other) if not isinstance(other, Expr): @@ -345,7 +356,18 @@ def pd_to_pa(pd_data: pd.DataFrame, schema: Dict[str, Type]): ) new_df[column] = cast_col_to_arrow_dtype(new_df[column], dtype) new_df = new_df.loc[:, list(schema.keys())] - return pa.RecordBatch.from_pandas(new_df, preserve_index=False) + fields = [] + for column, dtype in schema.items(): + proto_dtype = get_datatype(dtype) + pa_type, nullable = convert_dtype_to_arrow_type_with_nullable(proto_dtype) + field = pa.field(column, type=pa_type, nullable=nullable) + print("adding field", field, "dtype was", dtype, "proto_dtype", proto_dtype, "nullable", nullable) + print("field is nullable", field.nullable) + fields.append(field) + pa_schema = pa.schema(fields) + print("pa_schema", pa_schema) + print("new_df", new_df) + return pa.RecordBatch.from_pandas(new_df, preserve_index=False, schema=pa_schema) def pa_to_pd(pa_data, ret_type): ret = pa_data.to_pandas(types_mapper=pd.ArrowDtype) @@ -358,6 +380,7 @@ def pa_to_pd(pa_data, ret_type): proto_expr = serializer.serialize(self.root) proto_bytes = proto_expr.SerializeToString() df_pa = pd_to_pa(input_df, schema) + print("df_pa", df_pa) proto_schema = {} for key, value in schema.items(): proto_schema[key] = get_datatype(value).SerializeToString() @@ -443,6 +466,13 @@ class StringOp: class StrContains(StringOp): item: Expr +@dataclass +class StrStartsWith(StringOp): + item: Expr + +@dataclass +class StrEndsWith(StringOp): + item: Expr class Lower(StringOp): pass @@ -477,7 +507,6 @@ class Concat(StringOp): class _String(Expr): - def __init__(self, expr: Expr, op: StringOp): self.op = op self.operand = expr @@ -501,7 +530,7 @@ def len(self) -> _Number: return _Number(_String(self, StrLen()), MathNoop()) def strptime( - self, format: str, timezone: Optional[str] = None + self, format: str, timezone: Optional[str] = "UTC" ) -> _DateTime: return _DateTime( _String(self, StringStrpTime(format, timezone)), DateTimeNoop() @@ -510,6 +539,14 @@ def strptime( def parse(self, dtype: Type) -> Expr: return _String(self, StringParse(dtype)) + def startswith(self, item) -> _Bool: + item_expr = make_expr(item) + return _Bool(_String(self, StrStartsWith(item_expr))) + + def endswith(self, item) -> _Bool: + item_expr = make_expr(item) + return _Bool(_String(self, StrEndsWith(item_expr))) + ######################################################### # Dict Functions @@ -582,12 +619,12 @@ def __init__(self, expr: Expr, op: DateTimeOp): self.operand = expr super(_Struct, self).__init__() - def get(self, key: str) -> Expr: - if not isinstance(key, str): + def get(self, field: str) -> Expr: + if not isinstance(field, str): raise InvalidExprException( - f"invalid field access for struct, expected string but got {key}" + f"invalid field access for struct, expected string but got {field}" ) - return _Struct(self, StructGet(key)) + return _Struct(self, StructGet(field)) ######################################################### @@ -781,7 +818,7 @@ def __init__(self, c: Any, type: Type): class Unary(Expr): def __init__(self, op: str, operand: Any): - valid = ("~", "len", "str") + valid = ("~", "-") if op not in valid: raise InvalidExprException( "unary expressions only support %s but given '%s'" @@ -797,10 +834,7 @@ def __init__(self, op: str, operand: Any): super(Unary, self).__init__() def __str__(self) -> str: - if self.op in ["len", "str"]: - return f"{self.op}({self.operand})" - else: - return f"{self.op}{self.operand}" + return f"{self.op} {self.operand}" class Binary(Expr): @@ -953,14 +987,14 @@ def lit(v: Any, type: Optional[Type] = None) -> Expr: # TODO: Add support for more types recursively if type is not None: return Literal(v, type) + elif isinstance(v, bool): + return Literal(v, bool) elif isinstance(v, int): return Literal(v, int) elif isinstance(v, float): return Literal(v, float) elif isinstance(v, str): return Literal(v, str) - elif isinstance(v, bool): - return Literal(v, bool) elif v is None: return Literal(v, None) # type: ignore else: diff --git a/fennel/expr/serializer.py b/fennel/expr/serializer.py index e73685af1..58f346f96 100644 --- a/fennel/expr/serializer.py +++ b/fennel/expr/serializer.py @@ -45,6 +45,8 @@ StrLen, StringStrpTime, StringParse, + StrStartsWith, + StrEndsWith, Lower, Upper, StrContains, @@ -104,7 +106,12 @@ def visitRef(self, obj): def visitUnary(self, obj): expr = proto.Expr() - expr.unary.op = obj.op + if obj.op == "~": + expr.unary.op = proto.UnaryOp.NOT + elif obj.op == "-": + expr.unary.op = proto.UnaryOp.NEG + else: + raise Exception("invalid unary operation: %s" % obj.op) operand = self.visit(obj.operand) expr.unary.operand.CopyFrom(operand) return expr @@ -262,6 +269,18 @@ def visitString(self, obj): ) ) ) + elif isinstance(obj.op, StrStartsWith): + expr.string_fn.fn.CopyFrom( + proto.StringOp( + startswith=proto.StartsWith(key=self.visit(obj.op.item)) + ) + ) + elif isinstance(obj.op, StrEndsWith): + expr.string_fn.fn.CopyFrom( + proto.StringOp( + endswith=proto.EndsWith(key=self.visit(obj.op.item)) + ) + ) else: raise InvalidExprException("invalid string operation: %s" % obj.op) expr.string_fn.string.CopyFrom(self.visit(obj.operand)) diff --git a/fennel/expr/test_expr.py b/fennel/expr/test_expr.py index 2af57eb39..f8b8c0c5c 100644 --- a/fennel/expr/test_expr.py +++ b/fennel/expr/test_expr.py @@ -36,9 +36,29 @@ def test_basic_expr1(): ref_extractor.visit(expr.root) assert ref_extractor.refs == {"num", "d"} +def test_unary_expr(): + invert = ~col("a") + assert invert.typeof({"a": bool}) == bool + df = pd.DataFrame({"a": [True, False, True, False]}) + ret = invert.eval(df, {"a": bool}) + assert ret.tolist() == [False, True, False, True] + ref_extractor = FetchReferences() + ref_extractor.visit(invert.root) + assert ref_extractor.refs == {"a"} -def test_basic_expr2(): + negate = -col("a") + assert negate.typeof({"a": int}) == int + assert negate.typeof({"a": float}) == float + assert negate.typeof({"a": Optional[float]}) == Optional[float] + df = pd.DataFrame({"a": [1, 2, 3, 4]}) + ret = negate.eval(df, {"a": int}) + assert ret.tolist() == [-1, -2, -3, -4] + ref_extractor = FetchReferences() + ref_extractor.visit(negate.root) + assert ref_extractor.refs == {"a"} + +def test_basic_expr2(): expr = col("a") + col("b") + 3 printer = ExprPrinter() expected = "((col('a') + col('b')) + 3)" @@ -441,7 +461,7 @@ def test_datetime_expr(): {"a": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04"]} ), schema={"a": str}, - display="STRPTIME(col('a'), %Y-%m-%d)", + display="STRPTIME(col('a'), %Y-%m-%d, UTC)", refs={"a"}, eval_result=[ pd.Timestamp("2021-01-01 00:00:00+0000", tz="UTC"), @@ -580,10 +600,9 @@ def test_parse(): check_test_case(test_case) # Parse strings - # TODO(Aditya): Check why this is failing test_case = ExprTestCase( expr=(col("a").str.parse(str)), - df=pd.DataFrame({"a": ["a1", "b", "c", "d"]}), + df=pd.DataFrame({"a": ["\"a1\"", "\"b\"", "\"c\"", "\"d\""]}), schema={"a": str}, display="PARSE(col('a'), )", refs={"a"}, @@ -591,7 +610,7 @@ def test_parse(): expected_dtype=str, proto_json=None, ) - # check_test_case(test_case) + check_test_case(test_case) def test_list(): @@ -692,21 +711,27 @@ def test_list(): expected_dtype=bool, proto_json=None, ), - # (TODO: Aditya) Support for struct inside a list # Support struct inside a list - # ExprTestCase( - # #expr=(col("a").list.contains(make_struct({"x": 1, "y": 2, "z": "a"}, A))), - # expr=(col("a").list.len()), - # df=pd.DataFrame({"a": [[A(1, 2, "a"), A(2, 3, "b"), A(4, 5, "c")]]}), - # schema={"a": List[A]}, - # display="LEN(col('a'))", - # #display="""CONTAINS(col('a'), STRUCT(x=1, y=2, z="a"))""", - # refs={"a"}, - # eval_result=[True, False, False], - # expected_dtype=int, - # #expected_dtype=bool, - # proto_json=None, - # ), + ExprTestCase( + expr=(col("a").list.contains(make_struct({"x": 1, "y": 2, "z": "a"}, A))), + df=pd.DataFrame({"a": [[A(1, 2, "a"), A(2, 3, "b"), A(4, 5, "c")]]}), + schema={"a": List[A]}, + display="""CONTAINS(col('a'), STRUCT(x=1, y=2, z="a"))""", + refs={"a"}, + eval_result=[True], + expected_dtype=bool, + proto_json=None, + ), + ExprTestCase( + expr=(col("a").list.len()), + df=pd.DataFrame({"a": [[A(1, 2, "a"), A(2, 3, "b"), A(4, 5, "c")]]}), + schema={"a": List[A]}, + display="LEN(col('a'))", + refs={"a"}, + eval_result=[3], + expected_dtype=int, + proto_json=None, + ), # List length ExprTestCase( expr=(col("a").list.len()), @@ -858,7 +883,7 @@ def test_datetime(): } ), schema={"a": datetime}, - display="""SINCE(col('a'), STRPTIME("2021-01-01 00:01:00+0000", %Y-%m-%d %H:%M:%S%z), unit=TimeUnit.DAY)""", + display="""SINCE(col('a'), STRPTIME("2021-01-01 00:01:00+0000", %Y-%m-%d %H:%M:%S%z, UTC), unit=TimeUnit.DAY)""", refs={"a"}, eval_result=[0, 32, 61], expected_dtype=int, @@ -884,7 +909,7 @@ def test_datetime(): } ), schema={"a": datetime}, - display="""SINCE(col('a'), STRPTIME("2021-01-01 00:01:00+0000", %Y-%m-%d %H:%M:%S%z), unit=TimeUnit.YEAR)""", + display="""SINCE(col('a'), STRPTIME("2021-01-01 00:01:00+0000", %Y-%m-%d %H:%M:%S%z, UTC), unit=TimeUnit.YEAR)""", refs={"a"}, eval_result=[0, 0, 5], expected_dtype=int, @@ -1137,7 +1162,7 @@ def test_fillnull(): ), df=pd.DataFrame({"a": ["2021-01-01", None, "2021-01-03"]}), schema={"a": Optional[str]}, - display="""FILL_NULL(STRPTIME(col('a'), %Y-%m-%d), STRPTIME("2021-01-01", %Y-%m-%d, UTC))""", + display="""FILL_NULL(STRPTIME(col('a'), %Y-%m-%d, UTC), STRPTIME("2021-01-01", %Y-%m-%d, UTC))""", refs={"a"}, eval_result=[ pd.Timestamp("2021-01-01 00:00:00+0000", tz="UTC"), @@ -1154,50 +1179,64 @@ def test_fillnull(): def test_isnull(): cases = [ - ExprTestCase( - expr=(col("a").isnull()), - df=pd.DataFrame({"a": [1, 2, None, 4]}), - schema={"a": Optional[int]}, - display="IS_NULL(col('a'))", - refs={"a"}, - eval_result=[False, False, True, False], - expected_dtype=bool, - proto_json=None, - ), - ExprTestCase( - expr=(col("a").isnull()), - df=pd.DataFrame({"a": ["a", "b", None, "d"]}), - schema={"a": Optional[str]}, - display="IS_NULL(col('a'))", - refs={"a"}, - eval_result=[False, False, True, False], - expected_dtype=bool, - proto_json=None, - ), - # Each type is a struct - # TODO(Aditya): Fix this test case # ExprTestCase( # expr=(col("a").isnull()), - # df=pd.DataFrame({"a": [A(1, 2, "a"), A(2, 3, "b"), pd.NA]}), - # schema={"a": Optional[A]}, + # df=pd.DataFrame({"a": [1, 2, None, 4]}), + # schema={"a": Optional[int]}, # display="IS_NULL(col('a'))", # refs={"a"}, - # eval_result=[False, False, True], + # eval_result=[False, False, True, False], # expected_dtype=bool, # proto_json=None, # ), - # Each type is a list + # ExprTestCase( + # expr=(col("a").isnull()), + # df=pd.DataFrame({"a": ["a", "b", None, "d"]}), + # schema={"a": Optional[str]}, + # display="IS_NULL(col('a'))", + # refs={"a"}, + # eval_result=[False, False, True, False], + # expected_dtype=bool, + # proto_json=None, + # ), + # Each type is a struct + # TODO(Aditya): Fix this test case ExprTestCase( expr=(col("a").isnull()), - df=pd.DataFrame({"a": [[1, 2, 3], [4, 5, 6], None]}), - schema={"a": Optional[List[int]]}, + df=pd.DataFrame({"a": [A(1, 2, "a"), A(2, 3, "b"), None]}), + schema={"a": Optional[A]}, display="IS_NULL(col('a'))", refs={"a"}, eval_result=[False, False, True], expected_dtype=bool, proto_json=None, ), + # Each type is a list + # ExprTestCase( + # expr=(col("a").isnull()), + # df=pd.DataFrame({"a": [[1, 2, 3], [4, 5, 6], None]}), + # schema={"a": Optional[List[int]]}, + # display="IS_NULL(col('a'))", + # refs={"a"}, + # eval_result=[False, False, True], + # expected_dtype=bool, + # proto_json=None, + # ), ] for case in cases: check_test_case(case) + + +def test_complex_struct_parse(): + from fennel.dtypes import struct + @struct + class A: + x: int + y: int + z: Optional[bool] + + expr = col("a").str.parse(A).struct.get("z") + + # df = df.assign(**{json_col: lambda x: x[payload_col].fillna("{}").apply(json.loads)}) +# and then extract fields from json_col \ No newline at end of file diff --git a/fennel/expr/visitor.py b/fennel/expr/visitor.py index 5521d2e31..6e7cf4c1e 100644 --- a/fennel/expr/visitor.py +++ b/fennel/expr/visitor.py @@ -336,7 +336,7 @@ def visitRef(self, obj): self.refs.add(obj._col) def visitUnary(self, obj): - self.visit(obj.expr) + self.visit(obj.operand) def visitBinary(self, obj): self.visit(obj.left) diff --git a/fennel/internal_lib/schema/schema.py b/fennel/internal_lib/schema/schema.py index 79f5b1085..a4b4408f3 100644 --- a/fennel/internal_lib/schema/schema.py +++ b/fennel/internal_lib/schema/schema.py @@ -256,52 +256,60 @@ def get_python_type_from_pd(type): ] return type - -def convert_dtype_to_arrow_type(dtype: schema_proto.DataType) -> pa.DataType: +def convert_dtype_to_arrow_type_with_nullable(dtype: schema_proto.DataType) -> Tuple[pa.DataType, bool]: if dtype.HasField("optional_type"): - return convert_dtype_to_arrow_type(dtype.optional_type.of) + inner, _ = convert_dtype_to_arrow_type_with_nullable(dtype.optional_type.of) + return inner, True elif dtype.HasField("int_type"): - return pa.int64() + return pa.int64(), False elif dtype.HasField("double_type"): - return pa.float64() + return pa.float64(), False elif dtype.HasField("string_type") or dtype.HasField("regex_type"): - return pa.string() + return pa.string(), False elif dtype.HasField("bytes_type"): - return pa.binary() + return pa.binary(), False elif dtype.HasField("bool_type"): - return pa.bool_() + return pa.bool_(), False elif dtype.HasField("timestamp_type"): - return pa.timestamp("ns", "UTC") + return pa.timestamp("ns", "UTC"), False elif dtype.HasField("date_type"): - return pa.date32() + return pa.date32(), False elif dtype.HasField("decimal_type"): - return pa.decimal128(28, dtype.decimal_type.scale) + return pa.decimal128(28, dtype.decimal_type.scale), False elif dtype.HasField("array_type"): - return pa.list_( - value_type=convert_dtype_to_arrow_type(dtype.array_type.of) - ) + inner, nullable = convert_dtype_to_arrow_type_with_nullable(dtype.array_type.of) + field = pa.field("item", inner, nullable) + return pa.list_(field), False elif dtype.HasField("map_type"): - key_pa_type = convert_dtype_to_arrow_type(dtype.map_type.key) - value_pa_type = convert_dtype_to_arrow_type(dtype.map_type.value) - return pa.map_(key_pa_type, value_pa_type, False) + key_pa_type, nullable = convert_dtype_to_arrow_type_with_nullable(dtype.map_type.key) + key_field = pa.field("key", key_pa_type, nullable) + value_pa_type, nullable = convert_dtype_to_arrow_type_with_nullable(dtype.map_type.value) + value_field = pa.field("value", value_pa_type, nullable) + return pa.map_(key_field, value_field, False), False elif dtype.HasField("embedding_type"): embedding_size = dtype.embedding_type.embedding_size - return pa.list_(pa.float64(), embedding_size) + field = pa.field("item", pa.float64(), False) + return pa.list_(field) elif dtype.HasField("one_of_type"): - return convert_dtype_to_arrow_type(dtype.one_of_type.of) + return convert_dtype_to_arrow_type_with_nullable(dtype.one_of_type.of) elif dtype.HasField("between_type"): - return convert_dtype_to_arrow_type(dtype.between_type.dtype) + return convert_dtype_to_arrow_type_with_nullable(dtype.between_type.dtype) elif dtype.HasField("struct_type"): fields: List[Tuple[str, pa.DataType]] = [] for field in dtype.struct_type.fields: - fields.append( - (field.name, convert_dtype_to_arrow_type(field.dtype)) - ) - return pa.struct(fields) + inner, nullable = convert_dtype_to_arrow_type_with_nullable(field.dtype) + field = pa.field(field.name, inner, nullable) + fields.append(field) + return pa.struct(fields), False else: raise TypeError(f"Invalid dtype: {dtype}.") +def convert_dtype_to_arrow_type(dtype: schema_proto.DataType) -> pa.DataType: + atype, nullable = convert_dtype_to_arrow_type_with_nullable(dtype) + return atype + + def check_val_is_null(val: Any) -> bool: if isinstance(val, (list, tuple, dict, set, np.ndarray, frozendict)): return False @@ -933,13 +941,23 @@ def cast_col_to_arrow_dtype( # Let's convert structs into json, this is done because arrow # dtype conversion fails with fennel struct - if check_dtype_has_struct_type(dtype): - series = series.apply(lambda x: parse_struct_into_dict(x)) + print("in cast col to arrow dtype", dtype, check_dtype_has_struct_type(dtype)) # Parse datetime values series = series.apply(lambda x: parse_datetime_in_value(x, dtype)) - arrow_type = convert_dtype_to_arrow_type(dtype) - return series.astype(pd.ArrowDtype(arrow_type)) - + if check_dtype_has_struct_type(dtype): + before = series + series = series.apply(lambda x: parse_struct_into_dict(x)) + print(f"Converting struct into json: {before} -> {series}") + arrow_type, nullable = convert_dtype_to_arrow_type_with_nullable(dtype) + print("going for final conversion", series, arrow_type, nullable) + temp = series.astype(pd.ArrowDtype(arrow_type)) + if nullable: + print("beffore setting na", temp) + temp[series.isnull()] = pa.NA + print("after setting na", temp) + return temp + else: + return temp def check_dtype_has_struct_type(dtype: schema_proto.DataType) -> bool: if dtype.HasField("struct_type"): diff --git a/fennel/internal_lib/utils/utils.py b/fennel/internal_lib/utils/utils.py index ea675a7a0..c06d20e32 100644 --- a/fennel/internal_lib/utils/utils.py +++ b/fennel/internal_lib/utils/utils.py @@ -144,6 +144,8 @@ def parse_struct_into_dict(value: Any) -> Union[dict, list]: return [parse_struct_into_dict(x) for x in value] elif isinstance(value, dict) or isinstance(value, frozendict): return {key: parse_struct_into_dict(val) for key, val in value.items()} + elif value is None or pd.isna(value): + return None else: return value