From 7e6f9ef41c59d79efa61ab88a02c8a04aeef7d35 Mon Sep 17 00:00:00 2001 From: supercoder-dev Date: Wed, 26 Jun 2024 16:08:19 +0530 Subject: [PATCH] Add __array_function__ to dispatch NumPy funcs to dask-awkward --- src/dask_awkward/lib/core.py | 51 ++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index e52e43af..4d5b3f9d 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1689,6 +1689,57 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): **kwargs, ) + def __array_function__(self, func, types, args, kwargs): + import dask_awkward as dak + + # List of supported functions + supported_funcs = { + "amin": dak.amin, + "nanmin": dak.nanmin, + "count_nonzero": dak.count_nonzero, + "mean": dak.mean, + "nanmean": dak.nanmean, + "concatenate": dak.concatenate, + "sum": dak.sum, + "nansum": dak.nansum, + "argsort": dak.argsort, + "copy": dak.copy, + "nan_to_num": dak.nan_to_num, + "ones_like": dak.ones_like, + "zeros_like": dak.zeros_like, + "prod": dak.prod, + "nanprod": dak.nanprod, + "sort": dak.sort, + "var": dak.var, + "nanvar": dak.nanvar, + "round": dak.round, + "ptp": dak.ptp, + "any": dak.any, + "imag": dak.imag, + "real": dak.real, + "broadcast_arrays": dak.broadcast_arrays, + "std": dak.std, + "nanstd": dak.nanstd, + "isclose": dak.isclose, + "full_like": dak.full_like, + "all": dak.all, + "amax": dak.amax, + "nanmax": dak.nanmax, + "argmax": dak.argmax, + "nanargmax": dak.nanargmax, + "where": dak.where, + "angle": dak.angle, + "ravel": dak.ravel, + "argmin": dak.argmin, + "nanargmin": dak.nanargmin, + } + + # Check if the function is supported + if func.__name__ in supported_funcs: + return supported_funcs[func.__name__](*args, **kwargs) + else: + return NotImplemented + def __array__(self, *_, **__): raise NotImplementedError