Skip to content

Commit

Permalink
auto convert Python return value to Java value v2 (deephaven#4647)
Browse files Browse the repository at this point in the history
* Auto conv Py UDF returns and numba guvectorized

* Add numba guvectorize tests

* Harden the code for bad type annotations

* More meaningful names

* Clean up test code

* Apply suggestions from code review

Co-authored-by: Chip Kent <[email protected]>

* Add support for datetime data

* Update py/server/deephaven/dtypes.py

Co-authored-by: Chip Kent <[email protected]>

* Respond to review comments and array testcases

* Respond to reivew comments

* Respond to new review comments

* new test case for complex(unsupported) typehint

* Fix a couple of regressions

* Responding to review comments

---------

Co-authored-by: Chip Kent <[email protected]>
  • Loading branch information
jmao-denver and chipkent authored Oct 17, 2023
1 parent ec6d013 commit 4b792cf
Show file tree
Hide file tree
Showing 13 changed files with 709 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.github.javaparser.ast.comments.BlockComment;
import com.github.javaparser.ast.comments.JavadocComment;
import com.github.javaparser.ast.comments.LineComment;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.expr.ArrayAccessExpr;
import com.github.javaparser.ast.expr.ArrayCreationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
Expand Down Expand Up @@ -123,6 +122,7 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.time.Instant;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable,
// as a single argument
if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) {
expressions[ei] = new CastExpr(
new ClassOrInterfaceType("java.lang.Object"),
StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"),
expressions[ei]);
expressionTypes[ei] = Object.class;
} else {
Expand Down Expand Up @@ -2023,7 +2023,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == boolean.class && classB == Boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getThenExpr();
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
n.setThenExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand All @@ -2032,7 +2032,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == Boolean.class && classB == boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getElseExpr();
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
n.setElseExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand Down Expand Up @@ -2159,7 +2159,8 @@ public Class<?> visit(FieldAccessExpr n, VisitArgs printer) {
printer.append(", " + clsName + ".class");

final ClassExpr targetType =
new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName()));
new ClassExpr(
StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName()));
getAttributeArgs.add(targetType);

// Let's advertise to the caller the cast context type
Expand Down Expand Up @@ -2337,6 +2338,35 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS,
new PyCallableDetails(null, methodName));

if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) {
final Optional<Class<?>> optionalRetType = pyCallableReturnType(callMethodCall);
if (optionalRetType.isPresent()) {
Class<?> retType = optionalRetType.get();
final Optional<CastExpr> optionalCastExpr =
makeCastExpressionForPyCallable(retType, callMethodCall);
if (optionalCastExpr.isPresent()) {
final CastExpr castExpr = optionalCastExpr.get();
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
castExpr);

callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true);
try {
return castExpr.accept(this, printer);
} catch (Exception e) {
// exceptions could be thrown by {@link #tryVectorizePythonCallable}
replaceChildExpression(
castExpr.getParentNode().orElseThrow(),
castExpr,
callMethodCall);
callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)
.setCasted(false);
return callMethodCall.accept(this, printer);
}
}
}
}
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
Expand Down Expand Up @@ -2376,7 +2406,7 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr(
null,
new ClassOrInterfaceType(null, pyCallableWrapperImplName),
StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName),
NodeList.nodeList(getAttributeCall));

final MethodCallExpr callMethodCall = new MethodCallExpr(
Expand Down Expand Up @@ -2430,6 +2460,60 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
typeArguments);
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getSimpleName().toUpperCase())),
callMethodCall));
} else if (retType.getComponentType() != null) {
final Class<?> componentType = retType.getComponentType();
if (componentType.isPrimitive()) {
ArrayType arrayType;
if (componentType == boolean.class) {
arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"));
} else {
arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getComponentType().getSimpleName().toUpperCase())));
}
return Optional.of(new CastExpr(arrayType, callMethodCall));
} else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class
|| retType.getComponentType() == Instant.class) {
ArrayType arrayType =
new ArrayType(
StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName()));
return Optional.of(new CastExpr(arrayType, callMethodCall));
}
} else if (retType == Boolean.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall));
} else if (retType == String.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall));
} else if (retType == Instant.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall));
}

return Optional.empty();
}

private Optional<Class<?>> pyCallableReturnType(@NotNull MethodCallExpr n) {
final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
final String pyMethodName = pyCallableDetails.pythonMethodName;
final QueryScope queryScope = ExecutionContext.getContext().getQueryScope();
final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null);
if (paramValueRaw == null) {
return Optional.empty();
}
if (!(paramValueRaw instanceof PyCallableWrapper)) {
return Optional.empty();
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
pyCallableWrapper.parseSignature();
return Optional.ofNullable(pyCallableWrapper.getReturnType());
}

@NotNull
private static Expression[] getExpressionsArray(final NodeList<Expression> exprNodeList) {
return exprNodeList == null ? new Expression[0]
Expand Down Expand Up @@ -2558,18 +2642,37 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
@NotNull final PyCallableWrapper pyCallableWrapper) {

pyCallableWrapper.parseSignature();
if (!pyCallableWrapper.isVectorizableReturnType()) {
throw new PythonCallVectorizationFailure(
"Python function return type is not supported: " + pyCallableWrapper.getReturnType());
}

// Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated
// expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized
// functions must be used alone as the entire expression after removing the enclosing parentheses.

Node n1 = n;
boolean autoCastChecked = false;
while (n1.hasParentNode()) {
n1 = n1.getParentNode().orElseThrow();
Class<?> cls = n1.getClass();

if (cls == CastExpr.class) {
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized function can't be cast: " + n1);
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized functions can't be cast: " + n1);
}
} else if (cls == MethodCallExpr.class) {
String methodName = ((MethodCallExpr) n1).getNameAsString();
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()
&& methodName.endsWith("Cast")) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
}
} else if (cls != EnclosedExpr.class && cls != WrapperNode.class) {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
Expand Down Expand Up @@ -3233,6 +3336,17 @@ private static class PyCallableDetails {
@NotNull
private final String pythonMethodName;

@NotNull
private boolean isCasted = false;

public boolean isCasted() {
return isCasted;
}

public void setCasted(boolean casted) {
isCasted = casted;
}

private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) {
this.pythonScopeExpr = pythonScopeExpr;
this.pythonMethodName = pythonMethodName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ public Object getValue() {
}
}

boolean isVectorizableReturnType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import org.jpy.PyModule;
import org.jpy.PyObject;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs
Expand All @@ -20,6 +22,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class);

private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType();
private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType();

private static final PyModule dh_table_module = PyModule.importModule("deephaven.table");

Expand All @@ -34,9 +37,21 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('?', boolean.class);
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
numpyType2JavaClass.put('O', Object.class);
}

// TODO: support for vectorizing functions that return arrays
// https://github.com/deephaven/deephaven-core/issues/4649
private static final Set<Class<?>> vectorizableReturnTypes = Set.of(int.class, long.class, short.class, float.class,
double.class, byte.class, Boolean.class, String.class, Instant.class, PyObject.class);

@Override
public boolean isVectorizableReturnType() {
parseSignature();
return vectorizableReturnTypes.contains(returnType);
}

private final PyObject pyCallable;

private String signature = null;
Expand All @@ -47,6 +62,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private Collection<ChunkArgument> chunkArguments;
private boolean numbaVectorized;
private PyObject unwrapped;
private PyObject pyUdfDecoratedCallable;

public PyCallableWrapperJpyImpl(PyObject pyCallable) {
this.pyCallable = pyCallable;
Expand Down Expand Up @@ -91,23 +107,37 @@ private static PyObject getNumbaVectorizedFuncType() {
}
}

private static PyObject getNumbaGUVectorizedFuncType() {
try {
return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc");
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Numba isn't installed in the Python environment.");
}
return null;
}
}

private void prepareSignature() {
if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) {
boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE);
boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE);
if (isNumbaGUVectorized || isNumbaVectorized) {
List<PyObject> params = pyCallable.getAttribute("types").asList();
if (params.isEmpty()) {
throw new IllegalArgumentException(
"numba vectorized function must have an explicit signature: " + pyCallable);
"numba vectorized/guvectorized function must have an explicit signature: " + pyCallable);
}
// numba allows a vectorized function to have multiple signatures
if (params.size() > 1) {
throw new UnsupportedOperationException(
pyCallable
+ " has multiple signatures; this is not currently supported for numba vectorized functions");
+ " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions");
}
signature = params.get(0).getStringValue();
unwrapped = null;
numbaVectorized = true;
vectorized = true;
unwrapped = pyCallable;
// since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized
numbaVectorized = isNumbaVectorized;
vectorized = isNumbaVectorized;
} else if (pyCallable.hasAttribute("dh_vectorized")) {
signature = pyCallable.getAttribute("signature").toString();
unwrapped = pyCallable.getAttribute("callable");
Expand All @@ -119,6 +149,7 @@ private void prepareSignature() {
numbaVectorized = false;
vectorized = false;
}
pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped);
}

@Override
Expand All @@ -135,14 +166,6 @@ public void parseSignature() {
throw new IllegalStateException("Signature should always be available.");
}

char numpyTypeCode = signature.charAt(signature.length() - 1);
Class<?> returnType = numpyType2JavaClass.get(numpyTypeCode);
if (returnType == null) {
throw new IllegalStateException(
"Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: "
+ numpyTypeCode);
}

List<Class<?>> paramTypes = new ArrayList<>();
for (char numpyTypeChar : signature.toCharArray()) {
if (numpyTypeChar != '-') {
Expand All @@ -159,25 +182,31 @@ public void parseSignature() {
}

this.paramTypes = paramTypes;
if (returnType == Object.class) {
this.returnType = PyObject.class;
} else if (returnType == boolean.class) {

returnType = pyUdfDecoratedCallable.getAttribute("return_type", null);
if (returnType == null) {
throw new IllegalStateException(
"Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type");
}

if (returnType == boolean.class) {
this.returnType = Boolean.class;
} else {
this.returnType = returnType;
}
}

// In vectorized mode, we want to call the vectorized function directly.
public PyObject vectorizedCallable() {
if (numbaVectorized) {
if (numbaVectorized || vectorized) {
return pyCallable;
} else {
return dh_table_module.call("dh_vectorize", unwrapped);
}
}

// In non-vectorized mode, we want to call the udf decorated function or the original function.
@Override
public Object call(Object... args) {
PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable;
return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public void addChunkArgument(ChunkArgument ignored) {}

@Override
public Class<?> getReturnType() {
throw new UnsupportedOperationException();
return Object.class;
}

@Override
public boolean isVectorizableReturnType() {
return false;
}
}
Loading

0 comments on commit 4b792cf

Please sign in to comment.