diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index f07d5ae88cf77..d7ffc3efb24d8 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Union +from typing import Optional, Union, Literal # local import ivy @@ -44,6 +44,7 @@ def leaky_relu( *, alpha: float = 0.2, out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.leaky_relu. This method simply wraps @@ -59,6 +60,8 @@ def leaky_relu( out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. Returns ------- @@ -72,7 +75,9 @@ def leaky_relu( >>> print(y) ivy.array([ 0.39, -0.17]) """ - return ivy.leaky_relu(self._data, alpha=alpha, out=out) + return ivy.leaky_relu( + self._data, alpha=alpha, out=out, complex_mode=complex_mode + ) def gelu( self: ivy.Array, diff --git a/ivy/data_classes/array/experimental/layers.py b/ivy/data_classes/array/experimental/layers.py index d595e245a45a8..f52efa09783d7 100644 --- a/ivy/data_classes/array/experimental/layers.py +++ b/ivy/data_classes/array/experimental/layers.py @@ -832,6 +832,33 @@ def adaptive_avg_pool2d( output_size, ) + def adaptive_max_pool2d( + self: ivy.Array, + output_size: Union[Sequence[int], int], + ) -> ivy.Array: + """ + Apply a 2D adaptive maximum pooling over an input signal composed of several + input planes. + + Parameters + ---------- + self + Input array. Must have shape (N, C, H_in, W_in) or (C, H_in, W_in) where N + is the batch dimension, C is the feature dimension, and H_in and W_in are + the 2 spatial dimensions. + output_size + Spatial output size. + + Returns + ------- + The result of the pooling operation. Will have shape (N, C, S_0, S_1) or + (C, S_0, S_1), where S = `output_size` + """ + return ivy.adaptive_max_pool2d( + self._data, + output_size, + ) + def reduce_window( self: ivy.Array, init_value: Union[int, float], diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index 87387b1c253f0..dabd3be451f8e 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -1,7 +1,7 @@ # local import ivy from ivy.data_classes.container.base import ContainerBase -from typing import Optional, Union, List, Dict +from typing import Optional, Union, List, Dict, Literal # ToDo: implement all methods here as public instance methods @@ -140,6 +140,7 @@ def _static_leaky_relu( prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.leaky_relu. This method simply wraps @@ -166,6 +167,8 @@ def _static_leaky_relu( out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. 
Returns ------- @@ -191,6 +194,7 @@ def _static_leaky_relu( prune_unapplied=prune_unapplied, map_sequences=map_sequences, out=out, + complex_mode=complex_mode, ) def leaky_relu( @@ -203,6 +207,7 @@ def leaky_relu( prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.leaky_relu. This method simply @@ -229,6 +234,8 @@ def leaky_relu( out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. Returns ------- @@ -253,6 +260,7 @@ def leaky_relu( prune_unapplied=prune_unapplied, map_sequences=map_sequences, out=out, + complex_mode=complex_mode, ) @staticmethod diff --git a/ivy/data_classes/container/experimental/layers.py b/ivy/data_classes/container/experimental/layers.py index 38688d97d98d2..8bf647ea04dad 100644 --- a/ivy/data_classes/container/experimental/layers.py +++ b/ivy/data_classes/container/experimental/layers.py @@ -1910,6 +1910,78 @@ def adaptive_avg_pool2d( map_sequences=map_sequences, ) + @staticmethod + def static_adaptive_max_pool2d( + input: Union[ivy.Array, ivy.NativeArray, ivy.Container], + output_size: Union[Sequence[int], int, ivy.Container], + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.adaptive_max_pool2d. This method + simply wraps the function, and so the docstring for ivy.adaptive_max_pool2d also + applies to this method with minimal changes. + + Parameters + ---------- + input + Input array. Must have shape (N, C, H_in, W_in) or (C, H_in, W_in) where N + is the batch dimension, C is the feature dimension, and H_in and W_in are + the 2 spatial dimensions. + output_size + Spatial output size. + + Returns + ------- + The result of the pooling operation. Will have shape (N, C, S_0, S_1) or + (C, S_0, S_1), where S = `output_size` + """ + return ContainerBase.cont_multi_map_in_function( + "adaptive_max_pool2d", + input, + output_size, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + + def adaptive_max_pool2d( + self: ivy.Container, + output_size: Union[int, ivy.Container], + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + """ + Apply a 2D adaptive maximum pooling over an input signal composed of several + input planes. + + Parameters + ---------- + self + Input container. + output_size + Spatial output size. + + Returns + ------- + The result of the pooling operation. 
+ """ + return self.static_adaptive_max_pool2d( + self, + output_size, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + @staticmethod def static_ifftn( x: ivy.Container, diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index 19bee69912e26..7a7bcc80202c3 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -6,13 +6,16 @@ import warnings import copy as python_copy from types import FunctionType -from typing import Callable +from typing import Callable, Literal import inspect import numpy as np +from ivy.utils.exceptions import IvyValueError + # for wrapping (sequence matters) FN_DECORATORS = [ + "handle_complex_input", "infer_device", "handle_device_shifting", "infer_dtype", @@ -1385,6 +1388,127 @@ def _handle_nans(*args, **kwargs): return _handle_nans +# Complex number handling # +# ----------------------- # +def handle_complex_input(fn: Callable) -> Callable: + @functools.wraps(fn) + def _handle_complex_input( + inp, + *args, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + **kwargs, + ): + """ + Check whether the first positional argument is an array of complex type, and if + so handle it according to the provided `complex_mode`. + + The options are: + `"jax"` (default): emulate the behaviour of the JAX framework. If the function + has a `jax_like` attribute then this will be used to decide on the + behaviour (see below) and if not, then the entire array will be passed to + the function. + `"split"`: execute the function separately on the real and imaginary parts of + the input. + `"magnitude"`: execute the function on the magnitude of the input, and keep the + angle constant. + + The `jax_like` attribute (which should be added to the function itself, and not + passed as a parameter) has the following options: + `"entire"` (default): pass the entire input to the function. This is best used + for purely mathematical operators which are already well defined on complex + inputs, as many backends will throw exceptions otherwise. + `"split"`: as the `"split"` option for `complex_mode` + `"magnitude"`: as the `"magnitude"` option for `complex_mode` + A callable function: the function will be called instead of the originally + decorated function. It will be passed `inp` and `*args` as positional + arguments, and the original `**kwargs` plus `fn_original` as keyword + arguments. The latter is the original function, in case the `jax_like` + function wishes to call it. + + Parameters + ---------- + inp + The first positional argument to the function, which is expected to be an + :class:`ivy.Array`. + args + The remaining positional arguments to be passed to the function. + complex_mode + Optional argument which specifies the method that will be used to handle + the input, if it is complex. + kwargs + The keyword arguments to be passed to the function. + + Returns + ------- + The return of the function, with handling of inputs based + on the selected `complex_mode`. 
+ + Examples + -------- + Using the default `jax_like` behaviour + >>> @handle_complex_input + >>> def my_func(inp): + >>> return ivy.ones_like(inp) + + >>> x = ivy.array([1+1j, 3+4j, 5+12j]) + >>> my_func(x) # equivalent to setting complex_mode="jax" + ivy.array([1.+0.j, 1.+0.j, 1.+0.j]) + + >>> my_func(x, complex_mode="split") + ivy.array([1.+1.j, 1.+1.j, 1.+1.j]) + + >>> my_func(x, complex_mode="magnitude") + ivy.array([0.70710681+0.70710675j, 0.60000001+0.79999999j, + 0.38461535+0.92307694j]) + + Using non-default `jax_like` behaviour + >>> @handle_complex_input + >>> def my_func(inp): + >>> return ivy.ones_like(inp) + >>> my_func.jax_like = "split" + >>> my_func(x, complex_mode="jax") + ivy.array([1.+1.j, 1.+1.j, 1.+1.j]) + + Using callable `jax_like` behaviour + >>> def _my_func_jax_like(inp, fn_original=None): + >>> return fn_original(inp) * 3j + >>> @handle_complex_input + >>> def my_func(inp): + >>> return ivy.ones_like(inp) + >>> my_func.jax_like = _my_func_jax_like + >>> my_func(x, complex_mode="jax") + ivy.array([0.+3.j, 0.+3.j, 0.+3.j]) + """ + if not ivy.is_complex_dtype(inp): + return fn(inp, *args, **kwargs) + + jax_like = fn.jax_like if hasattr(fn, "jax_like") else "entire" + + if complex_mode == "split" or (complex_mode == "jax" and jax_like == "split"): + real_inp = ivy.real(inp) + imag_inp = ivy.imag(inp) + return fn(real_inp, *args, **kwargs) + 1j * fn(imag_inp, *args, **kwargs) + + elif complex_mode == "magnitude" or ( + complex_mode == "jax" and jax_like == "magnitude" + ): + mag_inp = ivy.abs(inp) + angle_inp = ivy.angle(inp) + return fn(mag_inp, *args, **kwargs) * ivy.exp(1j * angle_inp) + + elif complex_mode == "jax" and jax_like == "entire": + return fn(inp, *args, **kwargs) + + elif complex_mode == "jax": + return jax_like(inp, *args, **kwargs, fn_original=fn) + + else: + raise IvyValueError(f"complex_mode '{complex_mode}' is not recognised.") + + _handle_complex_input.handle_complex_input = True + return _handle_complex_input + + attribute_dict = { "unsupported_dtypes", "supported_dtypes", diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index c812a09e89320..303e20f06bfbe 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -38,6 +38,9 @@ def relu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
return F.relu(x) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version +) def leaky_relu( x: paddle.Tensor, /, diff --git a/ivy/functional/backends/paddle/experimental/layers.py b/ivy/functional/backends/paddle/experimental/layers.py index 962987b20064b..bdad826195364 100644 --- a/ivy/functional/backends/paddle/experimental/layers.py +++ b/ivy/functional/backends/paddle/experimental/layers.py @@ -381,6 +381,15 @@ def interpolate( raise IvyNotImplementedException() +def adaptive_max_pool2d( + input: paddle.Tensor, output_size: Union[Sequence[int], int] +) -> paddle.Tensor: + squeeze = input.ndim == 3 + x = paddle.unsqueeze(input, axis=0) if squeeze else input + ret = paddle.nn.functional.adaptive_max_pool2d(x, output_size) + return paddle.squeeze(ret, axis=0) if squeeze else ret + + def ifftn( x: paddle.Tensor, s: Optional[Union[int, Tuple[int]]] = None, diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index bd5b619d7dc84..c45c9f3201524 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -24,7 +24,6 @@ def gelu( return tf.nn.gelu(x, approximate) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def leaky_relu( x: Tensor, /, *, alpha: float = 0.2, out: Optional[Tensor] = None ) -> Tensor: diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index ff7c6721deea5..f5f4f500974a7 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -22,7 +22,7 @@ def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return torch.relu(x) -@with_unsupported_dtypes({"2.0.1 and below": ("complex", "float16")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def leaky_relu( x: torch.Tensor, /, diff --git a/ivy/functional/backends/torch/experimental/layers.py b/ivy/functional/backends/torch/experimental/layers.py index a297b9b37ddd1..ccb0384916977 100644 --- a/ivy/functional/backends/torch/experimental/layers.py +++ b/ivy/functional/backends/torch/experimental/layers.py @@ -865,6 +865,13 @@ def interpolate( ] +@with_unsupported_dtypes({"2.0.1 and below": ("bfloat16", "float16")}, backend_version) +def adaptive_max_pool2d( + input: torch.Tensor, output_size: Union[Sequence[int], int] +) -> torch.Tensor: + return torch.nn.functional.adaptive_max_pool2d(input, output_size) + + @with_unsupported_dtypes({"2.0.1 and below": ("bfloat16", "float16")}, backend_version) def adaptive_avg_pool1d(input, output_size): return torch.nn.functional.adaptive_avg_pool1d(input, output_size) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 835bc2b9fca4a..c94e6a69be361 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -8,21 +8,23 @@ def _type_conversion(x): # Does type conversion, floats maps to float, + # complex maps to complex, # 64bit dtype to float64, everything else to float32 x = ivy.asarray(x) dtype = ivy.as_ivy_dtype(x.dtype) - if "float" not in dtype: + if not ("float" in dtype or "complex" in dtype): dtype = "float64" if "64" in dtype[-2:] else "float32" - return ivy.astype(x, dtype) def _type_conversion_64(x): # Does type conversion, floats maps to float, - # 
everything else to float64 + # complex maps to complex, everything else to float64 x = ivy.asarray(x) dtype = ivy.as_ivy_dtype(x.dtype) - return ivy.astype(x, dtype) if "float" in dtype else ivy.astype(x, "float64") + if not ("float" in dtype or "complex" in dtype): + dtype = "float64" + return ivy.astype(x, dtype) def _batch_promotion(*args, default_dtype="float64"): @@ -160,7 +162,7 @@ def hard_tanh(x): @to_ivy_arrays_and_back def leaky_relu(x, negative_slope=0.01): x = _type_conversion_64(x) - return ivy.leaky_relu(x, alpha=negative_slope) + return ivy.leaky_relu(x, alpha=negative_slope, complex_mode="jax") @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/onnx/helper.py b/ivy/functional/frontends/onnx/helper.py index 3bd0250f28d55..af438f7e7a072 100644 --- a/ivy/functional/frontends/onnx/helper.py +++ b/ivy/functional/frontends/onnx/helper.py @@ -1,13 +1,13 @@ -import ivy -import ivy.functional.frontends.onnx as front_onnx from ivy.functional.frontends.onnx.proto import NodeProto from ivy_tests.test_ivy.helpers.testing_helpers import _import_fn -def make_node(op_type, inputs, outputs, name = None, doc_string = None, domain = None, **kwargs): +def make_node( + op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs +): # keep things upper case to follow ONNX naming convention - fn_tree = 'ivy.functional.frontends.onnx.' + op_type + fn_tree = "ivy.functional.frontends.onnx." + op_type callable_fn, fn_name, fn_mod = _import_fn(fn_tree) node = NodeProto() @@ -17,5 +17,5 @@ def make_node(op_type, inputs, outputs, name = None, doc_string = None, domain = node.input = inputs node.output = outputs node.name = name - - return node \ No newline at end of file + + return node diff --git a/ivy/functional/frontends/onnx/proto.py b/ivy/functional/frontends/onnx/proto.py index 4f2eb18dfcc8c..827ba60768857 100644 --- a/ivy/functional/frontends/onnx/proto.py +++ b/ivy/functional/frontends/onnx/proto.py @@ -1,4 +1,4 @@ -class NodeProto(): +class NodeProto: def __init__(self): self._fn = None self._fn_mod = None @@ -6,6 +6,6 @@ def __init__(self): self.input = None self.output = None self.name = None - + def __call__(self, *args, **kwargs): return self._fn(*args, **kwargs) diff --git a/ivy/functional/frontends/paddle/nn/functional/pooling.py b/ivy/functional/frontends/paddle/nn/functional/pooling.py index 30f6b69aea3bd..49bec303499bb 100644 --- a/ivy/functional/frontends/paddle/nn/functional/pooling.py +++ b/ivy/functional/frontends/paddle/nn/functional/pooling.py @@ -83,6 +83,12 @@ def adaptive_avg_pool2d(x, output_size, data_format="NCHW", name=None): return ivy.adaptive_avg_pool2d(x, output_size) +@to_ivy_arrays_and_back +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +def adaptive_max_pool2d(x, output_size, return_mask=None, name=None): + return ivy.adaptive_max_pool2d(x, output_size) + + @to_ivy_arrays_and_back @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def max_unpool1d( diff --git a/ivy/functional/frontends/paddle/nn/functional/vision.py b/ivy/functional/frontends/paddle/nn/functional/vision.py index 912816bb8f133..e1d69b7811760 100644 --- a/ivy/functional/frontends/paddle/nn/functional/vision.py +++ b/ivy/functional/frontends/paddle/nn/functional/vision.py @@ -5,6 +5,65 @@ from ivy.functional.frontends.paddle.func_wrapper import ( to_ivy_arrays_and_back, ) +from ivy.utils.assertions import check_equal + + +@to_ivy_arrays_and_back +def pixel_shuffle(x, upscale_factor, data_format="NCHW"): + input_shape 
= ivy.shape(x)
+    check_equal(
+        len(input_shape),
+        4,
+        message="pixel shuffle requires a 4D input, but got input size {}".format(
+            input_shape
+        ),
+    )
+
+    if not isinstance(upscale_factor, int):
+        raise ValueError("upscale factor must be int type")
+
+    if data_format not in ["NCHW", "NHWC"]:
+        raise ValueError(
+            "Attr(data_format) should be 'NCHW' or 'NHWC'."
+            "But received Attr(data_format): {} ".format(data_format)
+        )
+
+    b = input_shape[0]
+    c = input_shape[1] if data_format == "NCHW" else input_shape[3]
+    h = input_shape[2] if data_format == "NCHW" else input_shape[1]
+    w = input_shape[3] if data_format == "NCHW" else input_shape[2]
+
+    upscale_factor_squared = upscale_factor**2
+
+    check_equal(
+        c % upscale_factor_squared,
+        0,
+        message=(
+            "pixel shuffle expects input channel to be divisible by square of upscale"
+            " factor, but got input with sizes {}, upscale factor={}, and"
+            " self.size(1)={}, which is not divisible by {}".format(
+                input_shape, upscale_factor, c, upscale_factor_squared
+            )
+        ),
+        as_array=False,
+    )
+
+    oc = int(c / upscale_factor_squared)
+    oh = h * upscale_factor
+    ow = w * upscale_factor
+
+    if data_format == "NCHW":
+        input_reshaped = ivy.reshape(x, (b, oc, upscale_factor, upscale_factor, h, w))
+    else:
+        input_reshaped = ivy.reshape(x, (b, h, w, upscale_factor, upscale_factor, oc))
+
+    if data_format == "NCHW":
+        return ivy.reshape(
+            ivy.permute_dims(input_reshaped, (0, 1, 4, 2, 5, 3)), (b, oc, oh, ow)
+        )
+    return ivy.reshape(
+        ivy.permute_dims(input_reshaped, (0, 1, 3, 2, 4, 5)), (b, oh, ow, oc)
+    )
 
 
 @to_ivy_arrays_and_back
diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py
index 287c24fab9740..02ab5a09bc555 100644
--- a/ivy/functional/frontends/paddle/tensor/math.py
+++ b/ivy/functional/frontends/paddle/tensor/math.py
@@ -398,3 +398,11 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
 @to_ivy_arrays_and_back
 def any(x, axis=None, keepdim=False, name=None):
     return ivy.any(x, axis=axis, keepdims=keepdim)
+
+
+@with_supported_dtypes(
+    {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle"
+)
+@to_ivy_arrays_and_back
+def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
+    return ivy.diff(x, n=n, axis=axis, prepend=prepend, append=append)
diff --git a/ivy/functional/frontends/tensorflow/math.py b/ivy/functional/frontends/tensorflow/math.py
index cecc8bca7ba28..10840903ad429 100644
--- a/ivy/functional/frontends/tensorflow/math.py
+++ b/ivy/functional/frontends/tensorflow/math.py
@@ -432,17 +432,14 @@ def unsorted_segment_mean(
 
 
 @to_ivy_arrays_and_back
-def unsorted_segment_sum(
-    data, segment_ids, num_segments, name="unsorted_segment_sum"
-):
+def unsorted_segment_sum(data, segment_ids, num_segments, name="unsorted_segment_sum"):
     data = ivy.array(data)
     segment_ids = ivy.array(segment_ids)
     ivy.utils.assertions.check_equal(
         list(segment_ids.shape), [list(data.shape)[0]], as_array=False
     )
     sum_array = ivy.zeros(
-        tuple([num_segments.item()] + (list(data.shape))[1:]),
-        dtype=ivy.int32
+        tuple([num_segments.item()] + (list(data.shape))[1:]), dtype=ivy.int32
     )
     for i in range((segment_ids).shape[0]):
         sum_array[segment_ids[i]] = sum_array[segment_ids[i]] + data[i]
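For reference, the `pixel_shuffle` frontend added above is a single reshape-permute-reshape. A standalone NumPy sketch of the NCHW path (sizes and names here are illustrative, not part of the patch):

import numpy as np

# NCHW pixel shuffle: (b, oc * r * r, h, w) -> (b, oc, h * r, w * r)
b, oc, r, h, w = 1, 1, 2, 2, 2
x = np.arange(b * oc * r * r * h * w).reshape(b, oc * r * r, h, w)
tmp = x.reshape(b, oc, r, r, h, w)
# the permutation (0, 1, 4, 2, 5, 3) interleaves the two r axes with h and w
out = tmp.transpose(0, 1, 4, 2, 5, 3).reshape(b, oc, h * r, w * r)
assert out.shape == (1, 1, 4, 4)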
diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
index b4747cf9f797f..b1bacd44fa103 100644
--- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
@@ -159,6 +159,16 @@ def max_pool2d(
     return ret
 
 
+@to_ivy_arrays_and_back
+def adaptive_max_pool2d(
+    input,
+    output_size,
+    return_indices=False,
+):
+    # ToDo: Add return_indices once superset is implemented
+    return ivy.adaptive_max_pool2d(input, output_size)
+
+
 @with_unsupported_dtypes(
     {
         "2.0.1 and below": (
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index 7489926f0570d..0bfe93a8d5538 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -1,6 +1,6 @@
 """Collection of Ivy activation functions."""
 
-from typing import Union, Optional
+from typing import Union, Optional, Callable, Literal
 
 # local
 import ivy
@@ -12,6 +12,7 @@
     handle_nestable,
     handle_array_like_without_promotion,
     handle_device_shifting,
+    handle_complex_input,
 )
 from ivy.utils.exceptions import handle_exceptions
 
@@ -77,6 +78,25 @@ def gelu(
     return current_backend(x).gelu(x, approximate=approximate, out=out)
 
 
+def _leaky_relu_jax_like(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    fn_original: Optional[Callable] = None,
+    alpha: float = 0.2,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    return ivy.where(
+        (
+            ivy.logical_or(
+                ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0)
+            )
+        ),
+        ivy.astype(x * alpha, x.dtype),
+        x,
+    )
+
+
 @handle_exceptions
 @handle_nestable
 @handle_array_like_without_promotion
@@ -84,16 +104,23 @@ def gelu(
 @to_native_arrays_and_back
 @handle_array_function
 @handle_device_shifting
+@handle_complex_input
 def leaky_relu(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
     *,
     alpha: float = 0.2,
     out: Optional[ivy.Array] = None,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
 ) -> ivy.Array:
     """
     Apply the leaky rectified linear unit function element-wise.
 
+    If the input is complex, then by default each element is scaled by `alpha` if
+    either its real part is strictly negative or if its real part is zero and its
+    imaginary part is negative. This behaviour can be changed by specifying a different
+    `complex_mode`.
+
     Parameters
     ----------
     x
@@ -103,6 +130,9 @@
     out
         optional output array, for writing the result to. It must have a shape that the
         inputs broadcast to.
+    complex_mode
+        optional specifier for how to handle complex data types. See
+        `ivy.func_wrapper.handle_complex_input` for more detail.
     Returns
     -------
@@ -144,6 +174,9 @@
     return current_backend(x).leaky_relu(x, alpha=alpha, out=out)
 
 
+leaky_relu.jax_like = _leaky_relu_jax_like
+
+
 @handle_exceptions
 @handle_nestable
 @handle_array_like_without_promotion
diff --git a/ivy/functional/ivy/data_type.py b/ivy/functional/ivy/data_type.py
index d85cf700e6c29..04ff0ca1572b0 100644
--- a/ivy/functional/ivy/data_type.py
+++ b/ivy/functional/ivy/data_type.py
@@ -2492,34 +2492,24 @@ def _special_case(a1, a2):
         # check for float number and integer array case
         return isinstance(a1, float) and "int" in str(a2.dtype)
 
+    def _get_target_dtype(scalar, arr):
+        # identify a good dtype to give the scalar value,
+        # based on its own type and that of the arr value
+        if _special_case(scalar, arr):
+            return "float64"
+        elif arr.dtype == bool and not isinstance(scalar, bool):
+            return None  # let ivy infer a dtype
+        elif isinstance(scalar, complex) and not ivy.is_complex_dtype(arr):
+            return "complex128"
+        else:
+            return arr.dtype
+
     if hasattr(x1, "dtype") and not hasattr(x2, "dtype"):
         device = ivy.default_device(item=x1, as_native=True)
-        if x1.dtype == bool and not isinstance(x2, bool):
-            x2 = (
-                ivy.asarray(x2, device=device)
-                if not _special_case(x2, x1)
-                else ivy.asarray(x2, dtype="float64", device=device)
-            )
-        else:
-            x2 = (
-                ivy.asarray(x2, dtype=x1.dtype, device=device)
-                if not _special_case(x2, x1)
-                else ivy.asarray(x2, dtype="float64", device=device)
-            )
+        x2 = ivy.asarray(x2, dtype=_get_target_dtype(x2, x1), device=device)
     elif hasattr(x2, "dtype") and not hasattr(x1, "dtype"):
         device = ivy.default_device(item=x2, as_native=True)
-        if x2.dtype == bool and not isinstance(x1, bool):
-            x1 = (
-                ivy.asarray(x1, device=device)
-                if not _special_case(x1, x2)
-                else ivy.asarray(x1, dtype="float64", device=device)
-            )
-        else:
-            x1 = (
-                ivy.asarray(x1, dtype=x2.dtype, device=device)
-                if not _special_case(x1, x2)
-                else ivy.asarray(x1, dtype="float64", device=device)
-            )
+        x1 = ivy.asarray(x1, dtype=_get_target_dtype(x1, x2), device=device)
     elif not (hasattr(x1, "dtype") or hasattr(x2, "dtype")):
         x1 = ivy.asarray(x1)
         x2 = ivy.asarray(x2)
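To restate the branches of `_get_target_dtype` above: the scalar operand is given a dtype based on its own type and the array's. The example below is a sketch that assumes the helper is reached through ivy's usual binary-op promotion:

import ivy

# float scalar with an int array    -> scalar becomes float64 (_special_case)
# non-bool scalar with a bool array -> dtype None, ivy infers one
# complex scalar with a real array  -> scalar becomes complex128
# everything else                   -> scalar adopts the array's dtype
x = ivy.array([1.0, 2.0], dtype="float32")
y = x + (1 + 2j)  # complex branch: the scalar is converted as complex128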
diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py
index 3052ace23bfde..0e0e20f676c59 100644
--- a/ivy/functional/ivy/experimental/activations.py
+++ b/ivy/functional/ivy/experimental/activations.py
@@ -1,5 +1,5 @@
 # global
-from typing import Union, Optional, Sequence
+from typing import Union, Optional
 
 # local
 import ivy
@@ -459,15 +459,18 @@ def elu(
     """
     return current_backend(x).elu(x, alpha=alpha, out=out)
 
+
 def sequence_length(
-    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None) -> ivy.int64:
+    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
+) -> ivy.int64:
     """
-    Produces a scalar (tensor of empty shape) containing the number of tensors in the ivy array input.
+    Produces a scalar (tensor of empty shape) containing the number of tensors in the
+    ivy array input.
 
     Parameters
     ----------
     x
         Can be a sequence of any tensor type: bool, complex128, complex64, double, float,
         float16, int16, int32, int64, int8, string, uint16, uint32, uint64, uint8
 
     Returns
     -------
diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py
index f84be25d0967a..c529381fcf2b1 100644
--- a/ivy/functional/ivy/experimental/layers.py
+++ b/ivy/functional/ivy/experimental/layers.py
@@ -2081,7 +2081,7 @@ def _expand_to_dim(x, dim):
     return x
 
 
-def _mask(vals, length, range_max, dim):
+def _mask(vals, length, range_max, dim, mask_value=0.0):
     if isinstance(length, int):
         return vals, length
     else:
@@ -2089,11 +2089,103 @@
         mask = ivy.greater_equal(range_max, ivy.expand_dims(length, axis=-1))
         if dim == -2:
             mask = _expand_to_dim(mask, 4)
-        vals = ivy.where(mask, 0.0, vals)
+        vals = ivy.where(mask, ivy.array(mask_value, device=vals.device), vals)
         length = _expand_to_dim(length, -dim)
     return vals, length
 
 
+@handle_nestable
+@inputs_to_ivy_arrays
+def adaptive_max_pool2d(
+    input: Union[ivy.Array, ivy.NativeArray],
+    output_size: Union[Sequence[int], int],
+):
+    """
+    Apply a 2D adaptive maximum pooling over an input signal composed of several input
+    planes.
+
+    Parameters
+    ----------
+    input
+        Input array. Must have shape (N, C, H_in, W_in) or (C, H_in, W_in) where N is
+        the batch dimension, C is the feature dimension, and H_in and W_in are the 2
+        spatial dimensions.
+    output_size
+        Spatial output size.
+
+    Returns
+    -------
+    The result of the pooling operation. Will have shape (N, C, S_0, S_1) or
+    (C, S_0, S_1), where S = `output_size`
+    """
+    squeeze = False
+    if input.ndim == 3:
+        input = ivy.expand_dims(input, axis=0)
+        squeeze = True
+    elif input.ndim != 4:
+        raise ivy.utils.exceptions.IvyException(
+            f"Got {len(input.shape)}D input, but only 3D and 4D inputs are supported.",
+        )
+
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if all(i_s % o_s == 0 for i_s, o_s in zip(input.shape[-2:], output_size)):
+        stride = tuple(i_s // o_s for i_s, o_s in zip(input.shape[-2:], output_size))
+        kernel_size = stride  # the kernel size coincides with the stride here
+        pooled_output = ivy.max_pool2d(
+            input, kernel_size, stride, "VALID", data_format="NCHW"
+        )
+        if squeeze:
+            return ivy.squeeze(pooled_output, axis=0)
+        return pooled_output
+
+    idxh, length_h, range_max_h, adaptive_h = _compute_idx(
+        input.shape[-2], output_size[-2], input.device
+    )
+    idxw, length_w, range_max_w, adaptive_w = _compute_idx(
+        input.shape[-1], output_size[-1], input.device
+    )
+
+    # to numpy and back in order to bypass a slicing error in tensorflow
+    vals = ivy.array(
+        input.to_numpy()[..., _expand_to_dim(idxh, 4), idxw], device=input.device
+    )
+
+    if not adaptive_h and not adaptive_w:
+        ret = ivy.max(vals, axis=(-3, -1))
+        ret = ivy.squeeze(ret, axis=0) if squeeze else ret
+        return ret
+
+    vals, length_h = _mask(
+        vals, length_h, range_max_h, dim=-2, mask_value=float("-inf")
+    )
+    vals, length_w = _mask(
+        vals, length_w, range_max_w, dim=-1, mask_value=float("-inf")
+    )
+
+    ret = None
+    for i, j in itertools.product(range(vals.shape[-3]), range(vals.shape[-1])):
+        if ret is None:
+            ret = vals[..., i, :, j]
+        else:
+            ret = ivy.maximum(ret, vals[..., i, :, j])
+    pooled_output = ret.astype(vals.dtype)
+
+    pooled_output = ivy.squeeze(pooled_output, axis=0) if squeeze else pooled_output
+    return pooled_output
+
+
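A minimal NumPy reference for what the function above computes, assuming the usual floor/ceil adaptive-window convention (output cell (i, j) pools input rows [floor(i*H/OH), ceil((i+1)*H/OH)) and the analogous columns); the helper name is illustrative, not part of the patch:

import math
import numpy as np

def naive_adaptive_max_pool2d(x, output_size):
    # x: (N, C, H, W); output_size: (OH, OW)
    n, c, h, w = x.shape
    oh, ow = output_size
    out = np.empty((n, c, oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            h0, h1 = (i * h) // oh, math.ceil((i + 1) * h / oh)
            w0, w1 = (j * w) // ow, math.ceil((j + 1) * w / ow)
            out[..., i, j] = x[..., h0:h1, w0:w1].max(axis=(-2, -1))
    return out

When each input dim divides evenly by the output dim, every window has the same size, which is why the implementation above can fall back to the plain `max_pool2d` fast path.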
+adaptive_max_pool2d.mixed_backend_wrappers = { + "to_add": ( + "inputs_to_native_arrays", + "outputs_to_ivy_arrays", + "handle_device_shifting", + ), + "to_skip": ("inputs_to_ivy_arrays",), +} + + @handle_nestable @inputs_to_ivy_arrays def adaptive_avg_pool1d( @@ -2214,10 +2306,7 @@ def adaptive_avg_pool2d( if all(i_s % o_s == 0 for i_s, o_s in zip(input.shape[-2:], output_size)): stride = tuple(i_s // o_s for i_s, o_s in zip(input.shape[-2:], output_size)) - kernel_size = tuple( - i_s - (o_s - 1) * st - for i_s, o_s, st in zip(input.shape[-2:], output_size, stride) - ) + kernel_size = stride # Mathematically identical to the previous expression pooled_output = ivy.avg_pool2d( input, kernel_size, stride, "VALID", data_format="NCHW" ) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index ff768543984bf..0958770452912 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -3,6 +3,7 @@ # local import ivy from ivy.stateful.module import Module +from typing import Literal class GELU(Module): @@ -73,7 +74,11 @@ def _forward(self, x): class LeakyReLU(Module): - def __init__(self, alpha: float = 0.2): + def __init__( + self, + alpha: float = 0.2, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + ): """ Apply the LEAKY RELU activation function. @@ -81,11 +86,14 @@ def __init__(self, alpha: float = 0.2): ---------- alpha Negative slope for ReLU. + complex_mode + Specifies how to handle complex input. """ self._alpha = alpha + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, *, alpha=None): + def _forward(self, x, *, alpha=None, complex_mode=None): """ Parameters @@ -94,13 +102,19 @@ def _forward(self, x, *, alpha=None): Inputs to process *[batch_shape, d]*. alpha Negative slope for ReLU. + complex_mode + Specifies how to handle complex input. 
Returns ------- ret The outputs following the LEAKY RELU activation *[batch_shape, d]* """ - return ivy.leaky_relu(x, alpha=ivy.default(alpha, self._alpha)) + return ivy.leaky_relu( + x, + alpha=ivy.default(alpha, self._alpha), + complex_mode=ivy.default(complex_mode, self._complex_mode), + ) class LogSoftmax(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index a3aa70dc91db5..0ac715d8cc786 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -133,7 +133,7 @@ def test_jax_silu( @handle_frontend_test( fn_tree="jax.nn.leaky_relu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py index 6166fa95859b8..1863053b9432f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py @@ -190,6 +190,50 @@ def test_paddle_adaptive_avg_pool2d( ) +# adaptive_max_pool2d +@handle_frontend_test( + fn_tree="paddle.nn.functional.adaptive_max_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=4, + max_num_dims=4, + min_dim_size=1, + # Setting max and min value because this operation in paddle is not + # numerically stable + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=5), + helpers.ints(min_value=1, max_value=5), + ), + helpers.ints(min_value=1, max_value=5), + ), +) +def test_paddle_adaptive_max_pool2d( + *, + dtype_and_x, + output_size, + test_flags, + frontend, + on_device, + backend_fw, + fn_tree, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + on_device=on_device, + fn_tree=fn_tree, + x=x[0], + output_size=output_size, + ) + + # max_unpool1d @handle_frontend_test( fn_tree="paddle.nn.functional.max_unpool1d", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py index 27576be9d50dc..098e3aadbb852 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py @@ -1,11 +1,54 @@ # global -from hypothesis import strategies as st +import ivy +from hypothesis import assume, strategies as st # local import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test +# pixel_shuffle +@handle_frontend_test( + fn_tree="paddle.nn.functional.pixel_shuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=3, + ), + factor=helpers.ints(min_value=1), + data_format=st.sampled_from(["NCHW", "NHWC"]), +) +def 
test_paddle_pixel_shuffle(
+    *,
+    dtype_and_x,
+    factor,
+    data_format,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    if data_format == "NCHW":
+        assume(ivy.shape(x[0])[1] % (factor**2) == 0)
+    else:
+        assume(ivy.shape(x[0])[3] % (factor**2) == 0)
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+        upscale_factor=factor,
+        data_format=data_format,
+        backend_to_test=backend_fw,
+    )
+
+
 @st.composite
 def _affine_grid_helper(draw):
     align_corners = draw(st.booleans())
diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py
index f15157416e8db..7dc18e7a097d4 100644
--- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py
+++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py
@@ -1,4 +1,5 @@
 # global
+from hypothesis import strategies as st
 
 # local
 import ivy_tests.test_ivy.helpers as helpers
@@ -1782,3 +1783,54 @@ def test_paddle_any(
         axis=axis,
         keepdim=False,
     )
+
+
+# diff
+@handle_frontend_test(
+    fn_tree="paddle.tensor.math.diff",
+    dtype_n_x_n_axis=helpers.dtype_values_axis(
+        available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
+        min_num_dims=1,
+        valid_axis=True,
+        force_int_axis=True,
+    ),
+    n=st.integers(min_value=1, max_value=1),
+    dtype_prepend=helpers.dtype_and_values(
+        available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
+        min_num_dims=1,
+        max_num_dims=1,
+    ),
+    dtype_append=helpers.dtype_and_values(
+        available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
+        min_num_dims=1,
+        max_num_dims=1,
+    ),
+)
+def test_paddle_diff(
+    *,
+    dtype_n_x_n_axis,
+    n,
+    dtype_prepend,
+    dtype_append,
+    test_flags,
+    frontend,
+    backend_fw,
+    fn_tree,
+    on_device,
+):
+    input_dtype, x, axis = dtype_n_x_n_axis
+    _, prepend = dtype_prepend
+    _, append = dtype_append
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        test_flags=test_flags,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        x=x[0],
+        n=n,
+        axis=axis,
+        prepend=prepend[0],
+        append=append[0],
+    )
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
index e5300c6ec1451..e0e5530b5f5da 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
@@ -374,6 +374,52 @@ def test_torch_adaptive_avg_pool2d(
     )
 
 
+# adaptive_max_pool2d
+@handle_frontend_test(
+    fn_tree="torch.nn.functional.adaptive_max_pool2d",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=3,
+        max_num_dims=4,
+        min_dim_size=5,
+        # Setting max and min values because this operation is not numerically
+        # stable in the paddle backend
+        max_value=100,
+        min_value=-100,
+    ),
+    output_size=st.one_of(
+        st.tuples(
+            helpers.ints(min_value=1, max_value=10),
+            helpers.ints(min_value=1, max_value=10),
+        ),
+        helpers.ints(min_value=1, max_value=10),
+    ),
+    test_with_out=st.just(False),
+)
+def test_torch_adaptive_max_pool2d(
+    *,
+    dtype_and_x,
+    output_size,
+    on_device,
+    frontend,
+    test_flags,
+    fn_tree,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        input=x[0],
+        output_size=output_size,
+        atol=1e-2,
+    )
+
+
 # avg_pool1d
 @handle_frontend_test(
     fn_tree="torch.nn.functional.lp_pool1d",
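The `diff` frontend exercised by `test_paddle_diff` above forwards to `ivy.diff`, which follows `numpy.diff` semantics: the n-th discrete difference along an axis, with optional values prepended or appended before differencing. A quick NumPy illustration:

import numpy as np

x = np.array([1, 4, 9, 16])
np.diff(x)             # [3, 5, 7]     - first-order differences
np.diff(x, n=2)        # [2, 2]        - differencing applied twice
np.diff(x, prepend=0)  # [1, 3, 5, 7]  - 0 is prepended before differencing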
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
index 4a4a1151b85b2..903faa5e10458 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
@@ -857,6 +857,43 @@ def test_dft(
     )
 
 
+@handle_test(
+    fn_tree="functional.ivy.experimental.adaptive_max_pool2d",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=3,
+        max_num_dims=4,
+        min_dim_size=1,
+        # Setting max and min values because this operation is not numerically
+        # stable in the paddle backend
+        max_value=100,
+        min_value=-100,
+    ),
+    output_size=st.one_of(
+        st.tuples(
+            helpers.ints(min_value=1, max_value=5),
+            helpers.ints(min_value=1, max_value=5),
+        ),
+        helpers.ints(min_value=1, max_value=5),
+    ),
+    test_with_out=st.just(False),
+    ground_truth_backend="torch",
+)
+def test_adaptive_max_pool2d(
+    *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_function(
+        input_dtypes=input_dtype,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        on_device=on_device,
+        fn_name=fn_name,
+        input=x[0],
+        output_size=output_size,
+    )
+
+
 @handle_test(
     fn_tree="functional.ivy.experimental.adaptive_avg_pool1d",
     dtype_and_x=helpers.dtype_and_values(
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 6832a1eccc452..4a9f8c69b6249 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -34,7 +34,9 @@ def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.leaky_relu",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float", full=False, key="leaky_relu"),
+        available_dtypes=helpers.get_dtypes(
+            "float_and_complex", full=False, key="leaky_relu"
+        ),
         large_abs_safety_factor=16,
         small_abs_safety_factor=16,
         safety_factor_scale="log",
diff --git a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py
index 2fd2009c4f72d..3e5befbccdaef 100644
--- a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py
+++ b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py
@@ -140,3 +140,64 @@ def test_views(array_to_update, backend_fw):
     assert np.allclose(d, d_copy + 1)
     assert np.allclose(e[0], e_copy + 1)
     ivy.previous_backend()
+
+
+def _fn8(x):
+    return ivy.ones_like(x)
+
+
+def _jl(x, *args, fn_original, **kwargs):
+    return fn_original(x) * 3j
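`_fn8` and `_jl` above exercise the callable `jax_like` path: the wrapper invokes it as `jax_like(inp, *args, fn_original=fn, **kwargs)`. A minimal sketch of registering one on a user function (names illustrative; the attribute is attached before wrapping, as the test below does):

import ivy

def scale(x):
    return x * 2

def scale_jax_like(x, *args, fn_original, **kwargs):
    # delegate to the wrapped function, then rotate the result by 90 degrees
    return fn_original(x) * 1j

scale.jax_like = scale_jax_like
scale = ivy.handle_complex_input(scale)
# scale(ivy.array([1 + 1j])) -> (2 + 2j) * 1j == -2 + 2j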
+
+
+@pytest.mark.parametrize(
+    ("x", "mode", "jax_like", "expected"),
+    [
+        ([3.0, 7.0, -5.0], None, None, [1.0, 1.0, 1.0]),
+        ([3 + 4j, 7 - 6j, -5 - 2j], None, None, [1 + 0j, 1 + 0j, 1 + 0j]),
+        ([3 + 4j, 7 - 6j, -5 - 2j], "split", None, [1 + 1j, 1 + 1j, 1 + 1j]),
+        (
+            [3 + 4j, 7 - 6j, -5 - 2j],
+            "magnitude",
+            None,
+            [0.6 + 0.8j, 0.75926 - 0.65079j, -0.92848 - 0.37139j],
+        ),
+        ([3 + 4j, 7 - 6j, -5 - 2j], "jax", None, [1 + 0j, 1 + 0j, 1 + 0j]),
+        ([3 + 4j, 7 - 6j, -5 - 2j], "jax", "entire", [1 + 0j, 1 + 0j, 1 + 0j]),
+        ([3 + 4j, 7 - 6j, -5 - 2j], "jax", "split", [1 + 1j, 1 + 1j, 1 + 1j]),
+        (
+            [3 + 4j, 7 - 6j, -5 - 2j],
+            "jax",
+            "magnitude",
+            [0.6 + 0.8j, 0.75926 - 0.65079j, -0.92848 - 0.37139j],
+        ),
+        ([3 + 4j, 7 - 6j, -5 - 2j], "jax", _jl, [3j, 3j, 3j]),
+    ],
+)
+def test_handle_complex_input(x, mode, jax_like, expected, backend_fw):
+    ivy.set_backend(backend_fw)
+    x = ivy.array(x)
+    expected = ivy.array(expected)
+    if jax_like is not None:
+        _fn8.jax_like = jax_like
+    elif hasattr(_fn8, "jax_like"):
+        # _fn8 might have the jax_like attribute still attached from previous tests
+        delattr(_fn8, "jax_like")
+    test_fn = ivy.handle_complex_input(_fn8)
+    out = test_fn(x) if mode is None else test_fn(x, complex_mode=mode)
+    if "float" in x.dtype:
+        assert ivy.all(out == expected)
+    else:
+        assert ivy.all(
+            ivy.logical_and(
+                ivy.real(out) > ivy.real(expected) - 1e-4,
+                ivy.real(out) < ivy.real(expected) + 1e-4,
+            )
+        )
+        assert ivy.all(
+            ivy.logical_and(
+                ivy.imag(out) > ivy.imag(expected) - 1e-4,
+                ivy.imag(out) < ivy.imag(expected) + 1e-4,
+            )
+        )
+    ivy.previous_backend()
diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py
index 1c3bc130d10cc..1e7e1c7a49ead 100644
--- a/ivy_tests/test_ivy/test_stateful/test_activations.py
+++ b/ivy_tests/test_ivy/test_stateful/test_activations.py
@@ -144,7 +144,9 @@ def test_relu(
 @handle_method(
     method_tree="stateful.activations.LeakyReLU.__call__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float", full=False, key="leaky_relu"),
+        available_dtypes=helpers.get_dtypes(
+            "float_and_complex", full=False, key="leaky_relu"
+        ),
         large_abs_safety_factor=16,
         small_abs_safety_factor=16,
         safety_factor_scale="log",
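Taken together, these changes make complex inputs to `ivy.leaky_relu` configurable. A small usage sketch (the commented values are worked out by hand from the rules above, not captured output):

import ivy

x = ivy.array([0.5 - 1.0j, -2.0 + 1.0j])
ivy.leaky_relu(x)  # default "jax" mode: 0.5-1j is kept (real part > 0);
                   # -2+1j is scaled: (-2 + 1j) * 0.2 == -0.4 + 0.2j
ivy.leaky_relu(x, complex_mode="split")      # leaky_relu(real) + 1j * leaky_relu(imag)
ivy.leaky_relu(x, complex_mode="magnitude")  # leaky_relu(|x|) * exp(1j * angle(x))

The stateful layer accepts the same option, e.g. `ivy.stateful.activations.LeakyReLU(alpha=0.2, complex_mode="split")`.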