Skip to content

Commit

Permalink
Revert "[SPARK-48591][PYTHON] Simplify the if-else branches with `F.lit`"
Browse files Browse the repository at this point in the history

Revert #46946, since it may cause a circular import issue:
```
  File "/home/jenkins/python/pyspark/sql/connect/functions/__init__.py", line 20, in <module>
    from pyspark.sql.connect.functions.builtin import *  # noqa: F401,F403
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jenkins/python/pyspark/sql/connect/functions/builtin.py", line 60, in <module>
    from pyspark.sql.connect.udf import _create_py_udf
  File "/home/jenkins/python/pyspark/sql/connect/udf.py", line 38, in <module>
    from pyspark.sql.connect.column import Column
ImportError: cannot import name 'Column' from partially initialized module 'pyspark.sql.connect.column' (most likely due to a circular import) (/home/jenkins/python/pyspark/sql/connect/column.py)
Had test failures in delta.connect.tests.test_deltatable with python; see logs.
```

Closes #46985 from zhengruifeng/revert_simplify_column.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jun 14, 2024
1 parent 2d2bedf commit aa4bfb0
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@
Any,
Union,
Optional,
cast,
)

from pyspark.sql.column import Column as ParentColumn
from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError
from pyspark.sql.types import DataType

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.functions import builtin as F
from pyspark.sql.connect.expressions import (
Expression,
UnresolvedFunction,
Expand Down Expand Up @@ -310,12 +308,14 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
message_parameters={},
)

return Column(
CaseWhen(
branches=self._expr._branches + [(condition._expr, F.lit(value)._expr)],
else_value=None,
)
)
if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression._from_value(value)

_branches = self._expr._branches + [(condition._expr, _value)]

return Column(CaseWhen(branches=_branches, else_value=None))

def otherwise(self, value: Any) -> ParentColumn:
if not isinstance(self._expr, CaseWhen):
Expand All @@ -328,12 +328,12 @@ def otherwise(self, value: Any) -> ParentColumn:
"otherwise() can only be applied once on a Column previously generated by when()"
)

return Column(
CaseWhen(
branches=self._expr._branches,
else_value=cast(Expression, F.lit(value)._expr),
)
)
if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression._from_value(value)

return Column(CaseWhen(branches=self._expr._branches, else_value=_value))

def like(self: ParentColumn, other: str) -> ParentColumn:
return _bin_op("like", self, other)
Expand Down Expand Up @@ -457,11 +457,14 @@ def isin(self, *cols: Any) -> ParentColumn:
else:
_cols = list(cols)

return Column(
UnresolvedFunction(
"in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c in _cols]
)
)
_exprs = [self._expr]
for c in _cols:
if isinstance(c, Column):
_exprs.append(c._expr)
else:
_exprs.append(LiteralExpression._from_value(c))

return Column(UnresolvedFunction("in", _exprs))

def between(
self,
Expand Down Expand Up @@ -551,8 +554,10 @@ def __getitem__(self, k: Any) -> ParentColumn:
message_parameters={},
)
return self.substr(k.start, k.stop)
elif isinstance(k, Column):
return Column(UnresolvedExtractValue(self._expr, k._expr))
else:
return Column(UnresolvedExtractValue(self._expr, cast(Expression, F.lit(k)._expr)))
return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))

def __iter__(self) -> None:
raise PySparkTypeError(
Expand Down

0 comments on commit aa4bfb0

Please sign in to comment.