From 4b792cf7ae1558f2177475b7687d3d1bfce15c4a Mon Sep 17 00:00:00 2001 From: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:10:20 -0600 Subject: [PATCH] auto convert Python return value to Java value v2 (#4647) * Auto conv Py UDF returns and numba guvectorized * Add numba guvectorize tests * Harden the code for bad type annotations * More meaningful names * Clean up test code * Apply suggestions from code review Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Add support for datetime data * Update py/server/deephaven/dtypes.py Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Respond to review comments and array testcases * Respond to reivew comments * Respond to new review comments * new test case for complex(unsupported) typehint * Fix a couple of regressions * Responding to review comments --------- Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> --- .../table/impl/lang/QueryLanguageParser.java | 130 ++++++++++- .../engine/util/PyCallableWrapper.java | 1 + .../engine/util/PyCallableWrapperJpyImpl.java | 69 ++++-- .../impl/lang/PyCallableWrapperDummyImpl.java | 7 +- py/server/deephaven/column.py | 28 +-- py/server/deephaven/dtypes.py | 122 +++++++++- py/server/deephaven/table.py | 94 ++++++-- py/server/tests/test_numba_guvectorize.py | 94 ++++++++ .../tests/test_numba_vectorized_column.py | 2 +- .../tests/test_numba_vectorized_filter.py | 2 +- .../tests/test_pyfunc_return_java_values.py | 208 ++++++++++++++++++ py/server/tests/test_table_factory.py | 31 ++- py/server/tests/test_vectorization.py | 6 +- 13 files changed, 709 insertions(+), 85 deletions(-) create mode 100644 py/server/tests/test_numba_guvectorize.py create mode 100644 py/server/tests/test_pyfunc_return_java_values.py diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 6b166b2d797..71addc2fdaa 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -20,7 +20,6 @@ import com.github.javaparser.ast.comments.BlockComment; import com.github.javaparser.ast.comments.JavadocComment; import com.github.javaparser.ast.comments.LineComment; -import com.github.javaparser.ast.comments.Comment; import com.github.javaparser.ast.expr.ArrayAccessExpr; import com.github.javaparser.ast.expr.ArrayCreationExpr; import com.github.javaparser.ast.expr.ArrayInitializerExpr; @@ -123,6 +122,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; +import java.time.Instant; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable, // as a single argument if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) { expressions[ei] = new CastExpr( - new ClassOrInterfaceType("java.lang.Object"), + StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"), expressions[ei]); expressionTypes[ei] = Object.class; } else { @@ -2023,7 +2023,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == boolean.class && classB == Boolean.class) { // a little hacky, but this handles the null case where it unboxes. 
very weird stuff final Expression uncastExpr = n.getThenExpr(); - final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); n.setThenExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2032,7 +2032,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == Boolean.class && classB == boolean.class) { // a little hacky, but this handles the null case where it unboxes. very weird stuff final Expression uncastExpr = n.getElseExpr(); - final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); n.setElseExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2159,7 +2159,8 @@ public Class visit(FieldAccessExpr n, VisitArgs printer) { printer.append(", " + clsName + ".class"); final ClassExpr targetType = - new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName())); + new ClassExpr( + StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName())); getAttributeArgs.add(targetType); // Let's advertise to the caller the cast context type @@ -2337,6 +2338,35 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS, new PyCallableDetails(null, methodName)); + if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) { + final Optional> optionalRetType = pyCallableReturnType(callMethodCall); + if (optionalRetType.isPresent()) { + Class retType = optionalRetType.get(); + final Optional optionalCastExpr = + makeCastExpressionForPyCallable(retType, callMethodCall); + if (optionalCastExpr.isPresent()) { + final CastExpr castExpr = optionalCastExpr.get(); + replaceChildExpression( + n.getParentNode().orElseThrow(), + n, + castExpr); + + callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true); + try { + return castExpr.accept(this, printer); + } catch (Exception e) { + // exceptions could be thrown by {@link #tryVectorizePythonCallable} + replaceChildExpression( + castExpr.getParentNode().orElseThrow(), + castExpr, + callMethodCall); + callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS) + .setCasted(false); + return callMethodCall.accept(this, printer); + } + } + } + } replaceChildExpression( n.getParentNode().orElseThrow(), n, @@ -2376,7 +2406,7 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr( null, - new ClassOrInterfaceType(null, pyCallableWrapperImplName), + StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName), NodeList.nodeList(getAttributeCall)); final MethodCallExpr callMethodCall = new MethodCallExpr( @@ -2430,6 +2460,60 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { typeArguments); } + private Optional makeCastExpressionForPyCallable(Class retType, MethodCallExpr callMethodCall) { + if (retType.isPrimitive()) { + return Optional.of(new CastExpr( + new PrimitiveType(PrimitiveType.Primitive + .valueOf(retType.getSimpleName().toUpperCase())), + callMethodCall)); + } else if (retType.getComponentType() != null) { + final Class 
componentType = retType.getComponentType(); + if (componentType.isPrimitive()) { + ArrayType arrayType; + if (componentType == boolean.class) { + arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean")); + } else { + arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive + .valueOf(retType.getComponentType().getSimpleName().toUpperCase()))); + } + return Optional.of(new CastExpr(arrayType, callMethodCall)); + } else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class + || retType.getComponentType() == Instant.class) { + ArrayType arrayType = + new ArrayType( + StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName())); + return Optional.of(new CastExpr(arrayType, callMethodCall)); + } + } else if (retType == Boolean.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall)); + } else if (retType == String.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall)); + } else if (retType == Instant.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall)); + } + + return Optional.empty(); + } + + private Optional> pyCallableReturnType(@NotNull MethodCallExpr n) { + final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); + final String pyMethodName = pyCallableDetails.pythonMethodName; + final QueryScope queryScope = ExecutionContext.getContext().getQueryScope(); + final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null); + if (paramValueRaw == null) { + return Optional.empty(); + } + if (!(paramValueRaw instanceof PyCallableWrapper)) { + return Optional.empty(); + } + final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; + pyCallableWrapper.parseSignature(); + return Optional.ofNullable(pyCallableWrapper.getReturnType()); + } + @NotNull private static Expression[] getExpressionsArray(final NodeList exprNodeList) { return exprNodeList == null ? new Expression[0] @@ -2558,18 +2642,37 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, @NotNull final PyCallableWrapper pyCallableWrapper) { pyCallableWrapper.parseSignature(); + if (!pyCallableWrapper.isVectorizableReturnType()) { + throw new PythonCallVectorizationFailure( + "Python function return type is not supported: " + pyCallableWrapper.getReturnType()); + } // Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated // expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized // functions must be used alone as the entire expression after removing the enclosing parentheses. 
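        // Illustrative cases (`f` below is a hypothetical Python function in the query scope, wrapped as a
        // PyCallableWrapper):
        //   update("Y = f(X)")       - vectorizable: the call is the entire expression
        //   update("Y = (int) f(X)") - rejected: a user-written cast wraps the call
        //   update("Y = f(X) + 1")   - rejected: the call is embedded in a larger expression
        // The only wrappers tolerated by the walk below are parentheses, the single parser-generated cast
        // (tracked via isCasted), and the implicit `*Cast` helper call that cast may produce.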
+ Node n1 = n; + boolean autoCastChecked = false; while (n1.hasParentNode()) { n1 = n1.getParentNode().orElseThrow(); Class cls = n1.getClass(); if (cls == CastExpr.class) { - throw new PythonCallVectorizationFailure( - "The return values of Python vectorized function can't be cast: " + n1); + if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) { + autoCastChecked = true; + } else { + throw new PythonCallVectorizationFailure( + "The return values of Python vectorized functions can't be cast: " + n1); + } + } else if (cls == MethodCallExpr.class) { + String methodName = ((MethodCallExpr) n1).getNameAsString(); + if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted() + && methodName.endsWith("Cast")) { + autoCastChecked = true; + } else { + throw new PythonCallVectorizationFailure( + "Python vectorized function can't be used in another expression: " + n1); + } } else if (cls != EnclosedExpr.class && cls != WrapperNode.class) { throw new PythonCallVectorizationFailure( "Python vectorized function can't be used in another expression: " + n1); @@ -3233,6 +3336,17 @@ private static class PyCallableDetails { @NotNull private final String pythonMethodName; + @NotNull + private boolean isCasted = false; + + public boolean isCasted() { + return isCasted; + } + + public void setCasted(boolean casted) { + isCasted = casted; + } + private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) { this.pythonScopeExpr = pythonScopeExpr; this.pythonMethodName = pythonMethodName; diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java index b1dd6b19dea..bc22881b57a 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java @@ -84,4 +84,5 @@ public Object getValue() { } } + boolean isVectorizableReturnType(); } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 5f72ec781ac..18262f8e7f0 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -6,11 +6,13 @@ import org.jpy.PyModule; import org.jpy.PyObject; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -20,6 +22,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class); private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType(); + private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType(); private static final PyModule dh_table_module = PyModule.importModule("deephaven.table"); @@ -34,9 +37,21 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaClass.put('b', byte.class); numpyType2JavaClass.put('?', boolean.class); numpyType2JavaClass.put('U', String.class); + numpyType2JavaClass.put('M', Instant.class); 
numpyType2JavaClass.put('O', Object.class); } + // TODO: support for vectorizing functions that return arrays + // https://github.com/deephaven/deephaven-core/issues/4649 + private static final Set> vectorizableReturnTypes = Set.of(int.class, long.class, short.class, float.class, + double.class, byte.class, Boolean.class, String.class, Instant.class, PyObject.class); + + @Override + public boolean isVectorizableReturnType() { + parseSignature(); + return vectorizableReturnTypes.contains(returnType); + } + private final PyObject pyCallable; private String signature = null; @@ -47,6 +62,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private Collection chunkArguments; private boolean numbaVectorized; private PyObject unwrapped; + private PyObject pyUdfDecoratedCallable; public PyCallableWrapperJpyImpl(PyObject pyCallable) { this.pyCallable = pyCallable; @@ -91,23 +107,37 @@ private static PyObject getNumbaVectorizedFuncType() { } } + private static PyObject getNumbaGUVectorizedFuncType() { + try { + return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc"); + } catch (Exception e) { + if (log.isDebugEnabled()) { + log.debug("Numba isn't installed in the Python environment."); + } + return null; + } + } + private void prepareSignature() { - if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) { + boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE); + boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE); + if (isNumbaGUVectorized || isNumbaVectorized) { List params = pyCallable.getAttribute("types").asList(); if (params.isEmpty()) { throw new IllegalArgumentException( - "numba vectorized function must have an explicit signature: " + pyCallable); + "numba vectorized/guvectorized function must have an explicit signature: " + pyCallable); } // numba allows a vectorized function to have multiple signatures if (params.size() > 1) { throw new UnsupportedOperationException( pyCallable - + " has multiple signatures; this is not currently supported for numba vectorized functions"); + + " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions"); } signature = params.get(0).getStringValue(); - unwrapped = null; - numbaVectorized = true; - vectorized = true; + unwrapped = pyCallable; + // since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized + numbaVectorized = isNumbaVectorized; + vectorized = isNumbaVectorized; } else if (pyCallable.hasAttribute("dh_vectorized")) { signature = pyCallable.getAttribute("signature").toString(); unwrapped = pyCallable.getAttribute("callable"); @@ -119,6 +149,7 @@ private void prepareSignature() { numbaVectorized = false; vectorized = false; } + pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped); } @Override @@ -135,14 +166,6 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } - char numpyTypeCode = signature.charAt(signature.length() - 1); - Class returnType = numpyType2JavaClass.get(numpyTypeCode); - if (returnType == null) { - throw new IllegalStateException( - "Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: " - + numpyTypeCode); - } - List> paramTypes = new ArrayList<>(); for (char numpyTypeChar : signature.toCharArray()) { if (numpyTypeChar != '-') { @@ -159,25 +182,31 @@ public void parseSignature() { } this.paramTypes = paramTypes; 
- if (returnType == Object.class) { - this.returnType = PyObject.class; - } else if (returnType == boolean.class) { + + returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); + if (returnType == null) { + throw new IllegalStateException( + "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type"); + } + + if (returnType == boolean.class) { this.returnType = Boolean.class; - } else { - this.returnType = returnType; } } + // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { - if (numbaVectorized) { + if (numbaVectorized || vectorized) { return pyCallable; } else { return dh_table_module.call("dh_vectorize", unwrapped); } } + // In non-vectorized mode, we want to call the udf decorated function or the original function. @Override public Object call(Object... args) { + PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable; return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args)); } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index b1e405b67d0..573f8003f96 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -69,6 +69,11 @@ public void addChunkArgument(ChunkArgument ignored) {} @Override public Class getReturnType() { - throw new UnsupportedOperationException(); + return Object.class; + } + + @Override + public boolean isVectorizableReturnType() { + return false; } } diff --git a/py/server/deephaven/column.py b/py/server/deephaven/column.py index 635f60794f1..81fd5ffa202 100644 --- a/py/server/deephaven/column.py +++ b/py/server/deephaven/column.py @@ -9,13 +9,11 @@ from typing import Sequence, Any import jpy -import numpy as np -import pandas as pd import deephaven.dtypes as dtypes -from deephaven import DHError, time +from deephaven import DHError from deephaven.dtypes import DType -from deephaven.time import to_j_instant +from deephaven.dtypes import _instant_array _JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader") _JColumn = jpy.get_type("io.deephaven.qst.column.Column") @@ -206,27 +204,7 @@ def datetime_col(name: str, data: Sequence) -> InputColumn: Returns: a new input column """ - - # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on - # it to reduce the number of round trips to the JVM - if not isinstance(data, np.ndarray): - try: - data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) - except Exception as e: - ... 
- - if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): - if data.dtype.kind == 'M': - longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) - elif data.dtype.kind == 'i': - longs = jpy.array('long', data.astype('int64')) - else: # data.dtype.kind == 'U' - longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) - data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) - - if not isinstance(data, dtypes.instant_array.j_type): - data = [to_j_instant(d) for d in data] - + data = _instant_array(data) return InputColumn(name=name, data_type=dtypes.Instant, input_data=data) diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index a5f940f38eb..dddf778f635 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -8,7 +8,9 @@ """ from __future__ import annotations -from typing import Any, Sequence, Callable, Dict, Type, Union +import datetime +import sys +from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional import jpy import numpy as np @@ -120,6 +122,8 @@ def __call__(self, *args, **kwargs): """Python object type""" JObject = DType(j_name="java.lang.Object") """Java Object type""" +bool_array = DType(j_name='[Z') +"""boolean array type""" byte_array = DType(j_name='[B') """Byte array type""" int8_array = byte_array @@ -128,6 +132,8 @@ def __call__(self, *args, **kwargs): """Short array type""" int16_array = short_array """Short array type""" +char_array = DType(j_name='[C') +"""char array type""" int32_array = DType(j_name='[I') """32bit integer array type""" long_array = DType(j_name='[J') @@ -136,7 +142,7 @@ def __call__(self, *args, **kwargs): """64bit integer array type""" int_array = long_array """64bit integer array type""" -single_array = DType(j_name='[S') +single_array = DType(j_name='[F') """Single-precision floating-point array type""" float32_array = single_array """Single-precision floating-point array type""" @@ -148,6 +154,8 @@ def __call__(self, *args, **kwargs): """Double-precision floating-point array type""" string_array = DType(j_name='[Ljava.lang.String;') """Java String array type""" +boolean_array = DType(j_name='[Ljava.lang.Boolean;') +"""Java Boolean array type""" instant_array = DType(j_name='[Ljava.time.Instant;') """Java Instant array type""" zdt_array = DType(j_name='[Ljava.time.ZonedDateTime;') @@ -164,6 +172,19 @@ def __call__(self, *args, **kwargs): float64: NULL_DOUBLE, } +_BUILDABLE_ARRAY_DTYPE_MAP = { + bool_: bool_array, + byte: int8_array, + char: char_array, + int16: int16_array, + int32: int32_array, + int64: int64_array, + float32: float32_array, + float64: float64_array, + string: string_array, + Instant: instant_array, +} + def null_remap(dtype: DType) -> Callable[[Any], Any]: """ Creates a null value remap function for the provided DType. @@ -184,6 +205,34 @@ def null_remap(dtype: DType) -> Callable[[Any], Any]: return lambda v: null_value if v is None else v +def _instant_array(data: Sequence) -> jpy.JType: + """Converts a sequence of either datetime64[ns], datetime.datetime, pandas.Timestamp, datetime strings, + or integers in nanoseconds, to a Java array of Instant values. 
""" + # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on + # it to reduce the number of round trips to the JVM + if not isinstance(data, np.ndarray): + try: + data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) + except Exception as e: + ... + + if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): + if data.dtype.kind == 'M': + longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) + elif data.dtype.kind == 'i': + longs = jpy.array('long', data.astype('int64')) + else: # data.dtype.kind == 'U' + longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) + data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) + return data + + if not isinstance(data, instant_array.j_type): + from deephaven.time import to_j_instant + data = [to_j_instant(d) for d in data] + + return jpy.array(Instant.j_type, data) + + def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jpy.JType: """ Creates a Java array of the specified data type populated with values from a sequence. @@ -215,14 +264,14 @@ def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jp raise ValueError("Not a callable") seq = [remap(v) for v in seq] + if dtype == Instant: + return _instant_array(seq) + if isinstance(seq, np.ndarray): if dtype == bool_: bytes_ = seq.astype(dtype=np.int8) j_bytes = array(byte, bytes_) seq = _JPrimitiveArrayConversionUtility.translateArrayByteToBoolean(j_bytes) - elif dtype == Instant: - longs = jpy.array('long', seq.astype('datetime64[ns]').astype('int64')) - seq = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) return jpy.array(dtype.j_type, seq) except Exception as e: @@ -266,3 +315,66 @@ def from_np_dtype(np_dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]) - return dtype return PyObject + + +_NUMPY_INT_TYPE_CODES = ["i", "l", "h", "b"] +_NUMPY_FLOATING_TYPE_CODES = ["f", "d"] + + +def _scalar(x): + """Converts a Python value to a Java scalar value. It converts the numpy primitive types, string to + their Python equivalents so that JPY can handle them. For datetime values, it converts them to Java Instant. + Otherwise, it returns the value as is.""" + if hasattr(x, "dtype"): + if x.dtype.char in _NUMPY_INT_TYPE_CODES: + return int(x) + elif x.dtype.char in _NUMPY_FLOATING_TYPE_CODES: + return float(x) + elif x.dtype.char == '?': + return bool(x) + elif x.dtype.char == 'U': + return str(x) + elif x.dtype.char == 'O': + return x + elif x.dtype.char == 'M': + from deephaven.time import to_j_instant + return to_j_instant(x) + else: + raise TypeError(f"Unsupported dtype: {x.dtype}") + else: + if isinstance(x, (datetime.datetime, pd.Timestamp)): + from deephaven.time import to_j_instant + return to_j_instant(x) + return x + + +def _np_dtype_char(t: Union[type, str]) -> str: + """Returns the numpy dtype character code for the given type.""" + try: + np_dtype = np.dtype(t if t else "object") + if np_dtype.kind == "O": + if t in (datetime.datetime, pd.Timestamp): + return "M" + except TypeError: + np_dtype = np.dtype("object") + + return np_dtype.char + + +def _component_np_dtype_char(t: type) -> Optional[str]: + """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or + numpy ndarray, otherwise return None. 
""" + component_type = None + if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + + # np.ndarray as a generic alias is only supported in Python 3.9+ + if not component_type and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): + component_type = t.__args__[0] + + if component_type: + return _np_dtype_char(component_type) + else: + return None diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 15181e87691..78849e4cf46 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -11,10 +11,12 @@ import inspect from enum import Enum from enum import auto +from functools import wraps from typing import Any, Optional, Callable, Dict from typing import Sequence, List, Union, Protocol import jpy +import numba import numpy as np from deephaven import DHError @@ -29,6 +31,8 @@ from deephaven.jcompat import to_sequence, j_array_list from deephaven.update_graph import auto_locking_ctx, UpdateGraph from deephaven.updateby import UpdateByOperation +from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, \ + _component_np_dtype_char # Table _J_Table = jpy.get_type("io.deephaven.engine.table.Table") @@ -359,40 +363,93 @@ def _j_py_script_session() -> _JPythonScriptSession: return None -_numpy_type_codes = ["i", "l", "h", "f", "d", "b", "?", "U", "O"] +_SUPPORTED_NP_TYPE_CODES = ["i", "l", "h", "f", "d", "b", "?", "U", "M", "O"] def _encode_signature(fn: Callable) -> str: """Encode the signature of a Python function by mapping the annotations of the parameter types and the return - type to numpy dtype chars (i,l,h,f,d,b,?,U,O), and pack them into a string with parameter type chars first, + type to numpy dtype chars (i,l,h,f,d,b,?,U,M,O), and pack them into a string with parameter type chars first, in their original order, followed by the delimiter string '->', then the return type_char. If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used. """ sig = inspect.signature(fn) - parameter_types = [] + np_type_codes = [] for n, p in sig.parameters.items(): - try: - np_dtype = np.dtype(p.annotation if p.annotation else "object") - parameter_types.append(np_dtype) - except TypeError: - parameter_types.append(np.dtype("object")) - - try: - return_type = np.dtype(sig.return_annotation if sig.return_annotation else "object") - except TypeError: - return_type = np.dtype("object") + np_type_codes.append(_np_dtype_char(p.annotation)) - np_type_codes = [np.dtype(p).char for p in parameter_types] - np_type_codes = [c if c in _numpy_type_codes else "O" for c in np_type_codes] - return_type_code = np.dtype(return_type).char - return_type_code = return_type_code if return_type_code in _numpy_type_codes else "O" + return_type_code = _np_dtype_char(sig.return_annotation) + np_type_codes = [c if c in _SUPPORTED_NP_TYPE_CODES else "O" for c in np_type_codes] + return_type_code = return_type_code if return_type_code in _SUPPORTED_NP_TYPE_CODES else "O" np_type_codes.extend(["-", ">", return_type_code]) return "".join(np_type_codes) +def _py_udf(fn: Callable): + """A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between + Python and Java. This decorator is intended for use by the Deephaven query engine and should not be used by + users. 
+ + For now, this decorator is only capable of converting Python function return values to Java values. It + does not yet convert Java values in arguments to usable Python object (e.g. numpy arrays) or properly translate + Deephaven primitive null values. + + For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the + signature of the function and determines its return type, including supported primitive types and arrays of + the supported primitive types. It then converts the return value of the function to the corresponding Java value + of the same type. For unsupported types, the decorator returns the original Python value which appears as + org.jpy.PyObject in Java. + """ + + if hasattr(fn, "return_type"): + return fn + + if isinstance(fn, (numba.np.ufunc.dufunc.DUFunc, numba.np.ufunc.gufunc.GUFunc)) and hasattr(fn, "types"): + dh_dtype = dtypes.from_np_dtype(np.dtype(fn.types[0][-1])) + else: + dh_dtype = dtypes.from_np_dtype(np.dtype(_encode_signature(fn)[-1])) + + return_array = False + + # If the function is a numba guvectorized function, examine the signature of the function to determine if it + # returns an array. + if isinstance(fn, numba.np.ufunc.gufunc.GUFunc): + sig = fn.signature + rtype = sig.split("->")[-1].strip("()") + if rtype: + return_array = True + else: + component_type = _component_np_dtype_char(inspect.signature(fn).return_annotation) + if component_type: + dh_dtype = dtypes.from_np_dtype(np.dtype(component_type)) + if dh_dtype in _BUILDABLE_ARRAY_DTYPE_MAP: + return_array = True + + @wraps(fn) + def wrapper(*args, **kwargs): + ret = fn(*args, **kwargs) + if return_array: + return dtypes.array(dh_dtype, ret) + elif dh_dtype == dtypes.PyObject: + return ret + else: + return _scalar(ret) + + wrapper.j_name = dh_dtype.j_name + ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(dh_dtype) if return_array else dh_dtype + + if hasattr(dh_dtype.j_type, 'jclass'): + j_class = ret_dtype.j_type.jclass + else: + j_class = ret_dtype.qst_type.clazz() + + wrapper.return_type = j_class + + return wrapper + + def dh_vectorize(fn): """A decorator to vectorize a Python function used in Deephaven query formulas and invoked on a row basis. 
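As a reference for the `_encode_signature`/`_py_udf` additions above, the following minimal sketch shows what they produce for an annotated function. `add_one` is a made-up example, the dtype characters assume a 64-bit Linux numpy build (int64 -> 'l', float64 -> 'd'), and the snippet is meant to run inside a Deephaven Python session, not as part of this patch.

import numpy as np
from deephaven.table import _encode_signature, _py_udf

def add_one(x: np.int64) -> np.float64:
    return np.float64(x + 1)

# Parameter and return annotations are mapped to numpy dtype chars and packed as "params->return".
assert _encode_signature(add_one) == "l->d"

wrapped = _py_udf(add_one)
print(wrapped.j_name)       # "double" -- the Java type the query parser will auto-cast the call to
print(wrapped.return_type)  # the corresponding java.lang.Class, exposed to pyCallableReturnType() above
print(wrapped(1))           # 2.0 -- _scalar() turns the numpy float64 into a plain Python float for jpy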
@@ -410,6 +467,7 @@ def dh_vectorize(fn): """ signature = _encode_signature(fn) + @wraps(fn) def wrapper(*args): if len(args) != len(signature) - len("->?") + 2: raise ValueError( @@ -423,7 +481,7 @@ def wrapper(*args): vectorized_args = zip(*args[2:]) for i in range(chunk_size): scalar_args = next(vectorized_args) - chunk_result[i] = fn(*scalar_args) + chunk_result[i] = _scalar(fn(*scalar_args)) else: for i in range(chunk_size): chunk_result[i] = fn() diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py new file mode 100644 index 00000000000..c82b92296e3 --- /dev/null +++ b/py/server/tests/test_numba_guvectorize.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending +# + +import unittest + +import numpy as np +from numba import guvectorize, int64 + +from deephaven import empty_table, dtypes +from tests.testbase import BaseTestCase + +a = np.arange(5, dtype=np.int64) + + +class NumbaGuvectorizeTestCase(BaseTestCase): + def test_scalar_return(self): + # vector input to scalar output function (m)->() + @guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True) + def g(x, res): + res[0] = 0 + for xi in x: + res[0] += xi + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") + m = t.meta_table + self.assertEqual(t.columns[2].data_type, dtypes.int64) + + def test_vector_return(self): + # vector and scalar input to vector ouput function + @guvectorize([(int64[:], int64, int64[:])], "(m),()->(m)", nopython=True) + def g(x, y, res): + for i in range(len(x)): + res[i] = x[i] + y + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,2)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_fixed_length_vector_return(self): + # NOTE: the following does not work according to this thread from 7 years ago: + # https://numba-users.continuum.narkive.com/7OAX8Suv/numba-guvectorize-with-fixed-size-output-array + # but the latest numpy Generalized Universal Function API does seem to support frozen dimensions + # https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html#generalized-universal-function-api + # There is an old numba ticket about it + # https://github.com/numba/numba/issues/1668 + # Possibly we could contribute a fix + + # fails with: bad token in signature "2" + + # #vector input to fixed-length vector ouput function + # @guvectorize([(int64[:],int64[:])],"(m)->(2)",nopython=True) + # def g3(x, res): + # res[0] = min(x) + # res[1] = max(x) + + # t3 = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g3(Y)") + # m3 = t3.meta_table + + # ** Workaround ** + + dummy = np.array([0, 0], dtype=np.int64) + + # vector input to fixed-length vector ouput function -- second arg is a dummy just to get a fixed size output + @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) + def g(x, dummy, res): + res[0] = min(x) + res[1] = max(x) + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_np_on_java_array(self): + dummy = np.array([0, 0], dtype=np.int64) + + # vector input to fixed-length vector output function -- second arg is a dummy just to get a fixed size output + @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) + def g(x, dummy, res): + res[0] = np.min(x) + res[1] = np.max(x) + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + 
self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_np_on_java_array2(self): + @guvectorize([(int64[:], int64[:])], "(m)->(m)", nopython=True) + def g(x, res): + res[:] = x + 5 + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/tests/test_numba_vectorized_column.py b/py/server/tests/test_numba_vectorized_column.py index d47687c96d1..6f14eb7b659 100644 --- a/py/server/tests/test_numba_vectorized_column.py +++ b/py/server/tests/test_numba_vectorized_column.py @@ -16,7 +16,7 @@ def vectorized_func(x, y): return x % 3 + y -class TestNumbaVectorizedColumnClass(BaseTestCase): +class NumbaVectorizedColumnTestCase(BaseTestCase): def test_part_of_expr(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_numba_vectorized_filter.py b/py/server/tests/test_numba_vectorized_filter.py index 98155dadff4..468cc9e04c3 100644 --- a/py/server/tests/test_numba_vectorized_filter.py +++ b/py/server/tests/test_numba_vectorized_filter.py @@ -22,7 +22,7 @@ def vectorized_func_wrong_return_type(x, y): return x % 2 > y % 5 -class TestNumbaVectorizedFilterClass(BaseTestCase): +class NumbaVectorizedFilterTestCase(BaseTestCase): def test_wrong_return_type(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_pyfunc_return_java_values.py b/py/server/tests/test_pyfunc_return_java_values.py new file mode 100644 index 00000000000..0ed48a37505 --- /dev/null +++ b/py/server/tests/test_pyfunc_return_java_values.py @@ -0,0 +1,208 @@ +# +# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending +# +import datetime +import unittest +from typing import List, Union, Tuple, Sequence + +import numba as nb +import numpy as np +import pandas as pd + +from deephaven import empty_table, dtypes +from tests.testbase import BaseTestCase + +_J_TYPE_NP_DTYPE_MAP = { + dtypes.double: "np.float64", + dtypes.float32: "np.float32", + dtypes.int32: "np.int32", + dtypes.long: "np.int64", + dtypes.short: "np.int16", + dtypes.byte: "np.int8", + dtypes.bool_: "np.bool_", + dtypes.string: "np.str_", + # dtypes.char: "np.uint16", +} + + +class PyFuncReturnJavaTestCase(BaseTestCase): + def test_scalar_return(self): + for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): + with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): + func_str = f""" +def fn(col) -> {np_dtype}: + return {np_dtype}(col) +""" + exec(func_str, globals()) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dh_dtype) + + def test_array_return(self): + component_types = { + "int": dtypes.long_array, + "float": dtypes.double_array, + "np.int8": dtypes.byte_array, + "np.int16": dtypes.short_array, + "np.int32": dtypes.int32_array, + "np.int64": dtypes.long_array, + "np.float32": dtypes.float32_array, + "np.float64": dtypes.double_array, + "bool": dtypes.boolean_array, + "np.str_": dtypes.string_array, + # "np.uint16": dtypes.char_array, + } + container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + for component_type, dh_dtype in component_types.items(): + for container_type in container_types: + with self.subTest(component_type=component_type, container_type=container_type): + func_decl_str = f"""def fn(col) -> {container_type}[{component_type}]:""" + if container_type == "np.ndarray": + func_body_str = f""" return np.array([{component_type}(c) for c in col])""" + else: 
+ func_body_str = f""" return [{component_type}(c) for c in col]""" + exec("\n".join([func_decl_str, func_body_str]), globals()) + t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") + self.assertEqual(t.columns[2].data_type, dh_dtype) + + def test_scalar_return_class_method_not_supported(self): + for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): + with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): + func_str = f""" +class Foo: + def fn(self, col) -> {np_dtype}: + return {np_dtype}(col) +foo = Foo() +""" + exec(func_str, globals()) + + t = empty_table(10).update("X = i").update(f"Y= foo.fn(X + 1)") + self.assertNotEqual(t.columns[1].data_type, dh_dtype) + + def test_datetime_scalar_return(self): + dt_dtypes = [ + "np.dtype('datetime64[ns]')", + "np.dtype('datetime64[ms]')", + "datetime.datetime", + "pd.Timestamp" + ] + + for np_dtype in dt_dtypes: + with self.subTest(np_dtype=np_dtype): + func_decl_str = f"""def fn(col) -> {np_dtype}:""" + if np_dtype == "np.dtype('datetime64[ns]')": + func_body_str = f""" return pd.Timestamp(col).to_numpy()""" + elif np_dtype == "datetime.datetime": + func_body_str = f""" return pd.Timestamp(col).to_pydatetime()""" + elif np_dtype == "pd.Timestamp": + func_body_str = f""" return pd.Timestamp(col)""" + + exec("\n".join([func_decl_str, func_body_str]), globals()) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.Instant) + # vectorized + t = empty_table(10).update("X = i").update(f"Y= fn(X)") + self.assertEqual(t.columns[1].data_type, dtypes.Instant) + + def test_datetime_array_return(self): + + dt = datetime.datetime.now() + ts = pd.Timestamp(dt) + np_dt = np.datetime64(dt) + dt_list = [ts, np_dt, dt] + + # test if we can convert to numpy datetime64 array + np_array = np.array([pd.Timestamp(dt).to_numpy() for dt in dt_list], dtype=np.datetime64) + + dt_dtypes = [ + "np.ndarray[np.dtype('datetime64[ns]')]", + "List[datetime.datetime]", + "Tuple[pd.Timestamp]" + ] + + dt_data = [ + "dt_list", + "np_array", + ] + + # we are capable of mapping all datetime arrays (Sequence, np.ndarray) to instant arrays, so even if the actual + # return value of the function doesn't match its type hint (which will be caught by a static type checker), + # as long as it is a valid datetime collection, we can still convert it to instant array + for np_dtype in dt_dtypes: + for data in dt_data: + with self.subTest(np_dtype=np_dtype, data=data): + func_decl_str = f"""def fn(col) -> {np_dtype}:""" + func_body_str = f""" return {data}""" + exec("\n".join([func_decl_str, func_body_str]), globals().update( + {"dt_list": dt_list, "np_array": np_array})) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.instant_array) + + def test_return_value_errors(self): + def fn(col) -> List[object]: + return [col] + + def fn1(col) -> List: + return [col] + + def fn2(col): + return col + + def fn3(col) -> List[Union[datetime.datetime, int]]: + return [col] + + with self.subTest(fn): + t = empty_table(1).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn1): + t = empty_table(1).update("X = i").update(f"Y= fn1(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn2): + t = empty_table(1).update("X = i").update(f"Y= fn2(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn3): + t = 
empty_table(1).update("X = i").update(f"Y= fn3(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + + def test_vectorization_off_on_return_type(self): + def f1(x) -> List[str]: + return ["a"] + + t = empty_table(10).update("X = f1(3 + i)") + self.assertEqual(t.columns[0].data_type, dtypes.string_array) + + t = empty_table(10).update("X = f1(i)") + self.assertEqual(t.columns[0].data_type, dtypes.string_array) + + t = empty_table(10).update(["A=i%2", "B=i"]).group_by("A") + + # Testing https://github.com/deephaven/deephaven-core/issues/4557 + def f4557_1(x, y) -> np.ndarray[np.int64]: + # np.array is still needed as of v0.29 + return np.array(x) + y + + # Testing https://github.com/deephaven/deephaven-core/issues/4562 + @nb.guvectorize([(nb.int64[:], nb.int64, nb.int64[:])], "(m),()->(m)", nopython=True) + def f4562_1(x, y, res): + res[:] = x + y + + t2 = t.update([ + "X = f4557_1(B,3)", + "Y = f4562_1(B,3)" + ]) + self.assertEqual(t2.columns[2].data_type, dtypes.long_array) + self.assertEqual(t2.columns[3].data_type, dtypes.long_array) + + t3 = t2.ungroup() + self.assertEqual(t3.columns[2].data_type, dtypes.int64) + self.assertEqual(t3.columns[3].data_type, dtypes.int64) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/tests/test_table_factory.py b/py/server/tests/test_table_factory.py index c7bff4c19ec..9cb9aba2445 100644 --- a/py/server/tests/test_table_factory.py +++ b/py/server/tests/test_table_factory.py @@ -22,6 +22,7 @@ _JBlinkTableTools = jpy.get_type("io.deephaven.engine.table.impl.BlinkTableTools") _JDateTimeUtils = jpy.get_type("io.deephaven.time.DateTimeUtils") + @dataclass class CustomClass: f1: int @@ -56,7 +57,9 @@ def test_time_table(self): t = time_table("PT00:00:01", start_time="2021-11-06T13:21:00 ET") self.assertEqual(1, len(t.columns)) self.assertTrue(t.is_refreshing) - self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), time.to_j_time_zone('ET'))) + self.assertEqual("2021-11-06T13:21:00.000000000 ET", + _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), + time.to_j_time_zone('ET'))) t = time_table(1000_000_000) self.assertEqual(1, len(t.columns)) @@ -65,7 +68,9 @@ def test_time_table(self): t = time_table(1000_1000_1000, start_time="2021-11-06T13:21:00 ET") self.assertEqual(1, len(t.columns)) self.assertTrue(t.is_refreshing) - self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), time.to_j_time_zone('ET'))) + self.assertEqual("2021-11-06T13:21:00.000000000 ET", + _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), + time.to_j_time_zone('ET'))) p = time.to_timedelta(time.to_j_duration("PT1s")) t = time_table(p) @@ -76,7 +81,9 @@ def test_time_table(self): t = time_table(p, start_time=st) self.assertEqual(1, len(t.columns)) self.assertTrue(t.is_refreshing) - self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), time.to_j_time_zone('ET'))) + self.assertEqual("2021-11-06T13:21:00.000000000 ET", + _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), + time.to_j_time_zone('ET'))) def test_time_table_blink(self): t = time_table("PT1s", blink_table=True) @@ -334,5 +341,23 @@ def test_blink_to_append_only(self): self.assertTrue(st.is_refreshing) self.assertFalse(jpy.cast(st.j_table, _JBaseTable).isBlink()) + def 
test_instant_array(self): + from deephaven import DynamicTableWriter + from deephaven import dtypes as dht + from deephaven import time as dhtu + + col_defs_5 = \ + { \ + "InstantArray": dht.instant_array \ + } + + dtw5 = DynamicTableWriter(col_defs_5) + t5 = dtw5.table + dtw5.write_row(dht.array(dht.Instant, [dhtu.to_j_instant("2021-01-01T00:00:00 ET"), + dhtu.to_j_instant("2022-01-01T00:00:00 ET")])) + self.wait_ticking_table_update(t5, row_count=1, timeout=5) + self.assertEqual(t5.size, 1) + + if __name__ == '__main__': unittest.main() diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index b38004f8dcb..4bb941788fd 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -143,7 +143,7 @@ def pyfunc(p1, p2, p3) -> int: return p1 + p2 + p3 t = empty_table(1).update("X = i").update(["Y = pyfunc(X, i, 33)", "Z = pyfunc(X, ii, 66)"]) - self.assertEqual(deephaven.table._vectorized_count, 3) + self.assertEqual(deephaven.table._vectorized_count, 1) self.assertIn("33", t.to_string(cols=["Y"])) self.assertIn("66", t.to_string(cols=["Z"])) @@ -186,11 +186,11 @@ def pyfunc_bool(p1, p2, p3) -> bool: conditions = ["pyfunc_bool(I, 3, J)", "pyfunc_bool(i, 10, ii)"] filters = Filter.from_(conditions) t = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filters) - self.assertEqual(3, deephaven.table._vectorized_count) + self.assertEqual(1, deephaven.table._vectorized_count) filter_and = and_(filters) t1 = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filter_and) - self.assertEqual(5, deephaven.table._vectorized_count) + self.assertEqual(1, deephaven.table._vectorized_count) self.assertEqual(t1.size, t.size) self.assertEqual(9, t.size)
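Taken together, the user-visible effect of this change can be sketched as follows. This is a minimal example assuming a running Deephaven Python session; `price_bucket` and `history` are illustrative names, and the expected column types follow the new tests above.

import numpy as np
from typing import List
from deephaven import empty_table, dtypes

def price_bucket(x) -> np.int32:       # scalar return annotation
    return np.int32(x // 10)

def history(x) -> List[np.float64]:    # array return annotation
    return [np.float64(x), np.float64(x) / 2]

t = empty_table(10).update(["X = i", "B = price_bucket(X)", "H = history(X)"])

# With this patch the parser auto-casts the UDF results based on the annotations,
# so the columns are strongly typed instead of org.jpy.PyObject.
assert t.columns[1].data_type == dtypes.int32
assert t.columns[2].data_type == dtypes.double_array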