Skip to content

Commit

Permalink
Complex activations (ivy-llc#20850)
Browse files Browse the repository at this point in the history
  • Loading branch information
jshepherd01 authored Aug 4, 2023
1 parent cea8c90 commit bd03368
Show file tree
Hide file tree
Showing 14 changed files with 285 additions and 42 deletions.
9 changes: 7 additions & 2 deletions ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union
from typing import Optional, Union, Literal

# local
import ivy
Expand Down Expand Up @@ -44,6 +44,7 @@ def leaky_relu(
*,
alpha: float = 0.2,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.leaky_relu. This method simply wraps
Expand All @@ -59,6 +60,8 @@ def leaky_relu(
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Returns
-------
Expand All @@ -72,7 +75,9 @@ def leaky_relu(
>>> print(y)
ivy.array([ 0.39, -0.17])
"""
return ivy.leaky_relu(self._data, alpha=alpha, out=out)
return ivy.leaky_relu(
self._data, alpha=alpha, out=out, complex_mode=complex_mode
)

def gelu(
self: ivy.Array,
Expand Down
10 changes: 9 additions & 1 deletion ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# local
import ivy
from ivy.data_classes.container.base import ContainerBase
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Literal


# ToDo: implement all methods here as public instance methods
Expand Down Expand Up @@ -140,6 +140,7 @@ def _static_leaky_relu(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.leaky_relu. This method simply wraps
Expand All @@ -166,6 +167,8 @@ def _static_leaky_relu(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Returns
-------
Expand All @@ -191,6 +194,7 @@ def _static_leaky_relu(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def leaky_relu(
Expand All @@ -203,6 +207,7 @@ def leaky_relu(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.leaky_relu. This method simply
Expand All @@ -229,6 +234,8 @@ def leaky_relu(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Returns
-------
Expand All @@ -253,6 +260,7 @@ def leaky_relu(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down
126 changes: 125 additions & 1 deletion ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import warnings
import copy as python_copy
from types import FunctionType
from typing import Callable
from typing import Callable, Literal
import inspect
import numpy as np

from ivy.utils.exceptions import IvyValueError


# for wrapping (sequence matters)
FN_DECORATORS = [
"handle_complex_input",
"infer_device",
"handle_device_shifting",
"infer_dtype",
Expand Down Expand Up @@ -1385,6 +1388,127 @@ def _handle_nans(*args, **kwargs):
return _handle_nans


# Complex number handling #
# ----------------------- #
def handle_complex_input(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def _handle_complex_input(
        inp,
        *args,
        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
        **kwargs,
    ):
        """
        Dispatch a call on a possibly-complex first argument according to
        `complex_mode`.

        If `inp` does not have a complex dtype, the wrapped function is called
        unchanged and `complex_mode` is not inspected. Otherwise one of the
        following strategies is applied:

        `"jax"` (default): defer to the `jax_like` attribute of the wrapped function
            (see below). When the attribute is absent, `"entire"` is assumed.
        `"split"`: call the function once on the real part and once on the imaginary
            part, and recombine the results as `real_result + 1j * imag_result`.
        `"magnitude"`: call the function on the absolute value of the input and
            multiply the result by `exp(1j * angle(inp))`, so the phase of the
            input is re-applied to the output.

        The `jax_like` attribute is read from the wrapped function `fn` (it is not a
        call-time parameter) and may be:

        `"entire"` (default): pass the complex input straight to the function.
        `"split"`: same as `complex_mode="split"`.
        `"magnitude"`: same as `complex_mode="magnitude"`.
        A callable: it is called *instead of* the wrapped function, with `inp` and
            `*args` as positional arguments and `**kwargs` plus `fn_original=fn` as
            keyword arguments, so that it can delegate to the original function.

        Parameters
        ----------
        inp
            The first positional argument to the function; its dtype decides
            whether any complex handling takes place.
        args
            The remaining positional arguments, forwarded untouched.
        complex_mode
            Which strategy to use when `inp` is complex. Consumed by this wrapper
            and not forwarded to the wrapped function.
        kwargs
            The keyword arguments, forwarded untouched.

        Returns
        -------
            The result of the wrapped function (or of the `jax_like` callable),
            combined as described above for the selected strategy.

        Raises
        ------
        IvyValueError
            If `inp` is complex and `complex_mode` is not one of the recognised
            values.

        Examples
        --------
        NOTE(review): the examples below assign `jax_like` on the decorated name,
        whereas the lookup is performed on the wrapped `fn`; confirm that the
        attribute reaches `fn` in the way these examples assume.

        Using the default `jax_like` behaviour

        >>> @handle_complex_input
        >>> def my_func(inp):
        >>>     return ivy.ones_like(inp)

        >>> x = ivy.array([1+1j, 3+4j, 5+12j])
        >>> my_func(x)  # equivalent to setting complex_mode="jax"
        ivy.array([1.+0.j, 1.+0.j, 1.+0.j])

        >>> my_func(x, complex_mode="split")
        ivy.array([1.+1.j, 1.+1.j, 1.+1.j])

        >>> my_func(x, complex_mode="magnitude")
        ivy.array([0.70710681+0.70710675j, 0.60000001+0.79999999j,
                   0.38461535+0.92307694j])

        Using non-default `jax_like` behaviour

        >>> @handle_complex_input
        >>> def my_func(inp):
        >>>     return ivy.ones_like(inp)
        >>> my_func.jax_like = "split"
        >>> my_func(x, complex_mode="jax")
        ivy.array([1.+1.j, 1.+1.j, 1.+1.j])

        Using callable `jax_like` behaviour

        >>> def _my_func_jax_like(inp, fn_original=None):
        >>>     return fn_original(inp) * 3j
        >>> @handle_complex_input
        >>> def my_func(inp):
        >>>     return ivy.ones_like(inp)
        >>> my_func.jax_like = _my_func_jax_like
        >>> my_func(x, complex_mode="jax")
        ivy.array([0.+3.j, 0.+3.j, 0.+3.j])
        """
        # Non-complex input bypasses every strategy below.
        if not ivy.is_complex_dtype(inp):
            return fn(inp, *args, **kwargs)

        jax_like = getattr(fn, "jax_like", "entire")
        follows_jax = complex_mode == "jax"
        # Under "jax" the function's own `jax_like` setting picks the strategy;
        # otherwise the caller's explicit choice is used directly.
        strategy = jax_like if follows_jax else complex_mode

        if strategy == "split":
            real_inp = ivy.real(inp)
            imag_inp = ivy.imag(inp)
            real_out = fn(real_inp, *args, **kwargs)
            imag_out = fn(imag_inp, *args, **kwargs)
            return real_out + 1j * imag_out

        if strategy == "magnitude":
            mag_inp = ivy.abs(inp)
            angle_inp = ivy.angle(inp)
            return fn(mag_inp, *args, **kwargs) * ivy.exp(1j * angle_inp)

        # Anything other than "split"/"magnitude"/"jax" is an invalid mode.
        if not follows_jax:
            raise IvyValueError(f"complex_mode '{complex_mode}' is not recognised.")

        if jax_like == "entire":
            return fn(inp, *args, **kwargs)

        # Remaining case: `jax_like` is expected to be a callable replacement.
        return jax_like(inp, *args, **kwargs, fn_original=fn)

    # Marker attribute set on the wrapper so it can be identified as complex-aware.
    _handle_complex_input.handle_complex_input = True
    return _handle_complex_input


attribute_dict = {
"unsupported_dtypes",
"supported_dtypes",
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/paddle/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def relu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
return F.relu(x)


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def leaky_relu(
x: paddle.Tensor,
/,
Expand Down
1 change: 0 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def gelu(
return tf.nn.gelu(x, approximate)


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
def leaky_relu(
x: Tensor, /, *, alpha: float = 0.2, out: Optional[Tensor] = None
) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten
return torch.relu(x)


@with_unsupported_dtypes({"2.0.1 and below": ("complex", "float16")}, backend_version)
@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def leaky_relu(
x: torch.Tensor,
/,
Expand Down
12 changes: 7 additions & 5 deletions ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@

def _type_conversion(x):
# Does type conversion, floats maps to float,
# complex maps to complex,
# 64bit dtype to float64, everything else to float32
x = ivy.asarray(x)
dtype = ivy.as_ivy_dtype(x.dtype)
if "float" not in dtype:
if not ("float" in dtype or "complex" in dtype):
dtype = "float64" if "64" in dtype[-2:] else "float32"

return ivy.astype(x, dtype)


def _type_conversion_64(x):
# Does type conversion, floats maps to float,
# everything else to float64
# complex maps to complex, everything else to float64
x = ivy.asarray(x)
dtype = ivy.as_ivy_dtype(x.dtype)
return ivy.astype(x, dtype) if "float" in dtype else ivy.astype(x, "float64")
if not ("float" in dtype or "complex" in dtype):
dtype = "float64"
return ivy.astype(x, dtype)


def _batch_promotion(*args, default_dtype="float64"):
Expand Down Expand Up @@ -160,7 +162,7 @@ def hard_tanh(x):
@to_ivy_arrays_and_back
def leaky_relu(x, negative_slope=0.01):
x = _type_conversion_64(x)
return ivy.leaky_relu(x, alpha=negative_slope)
return ivy.leaky_relu(x, alpha=negative_slope, complex_mode="jax")


@to_ivy_arrays_and_back
Expand Down
35 changes: 34 additions & 1 deletion ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Collection of Ivy activation functions."""

from typing import Union, Optional
from typing import Union, Optional, Callable, Literal

# local
import ivy
Expand All @@ -12,6 +12,7 @@
handle_nestable,
handle_array_like_without_promotion,
handle_device_shifting,
handle_complex_input,
)
from ivy.utils.exceptions import handle_exceptions

Expand Down Expand Up @@ -77,23 +78,49 @@ def gelu(
return current_backend(x).gelu(x, approximate=approximate, out=out)


def _leaky_relu_jax_like(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    fn_original: Optional[Callable] = None,
    alpha: float = 0.2,
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """
    Leaky ReLU variant used for complex input under the `"jax"` complex mode.

    An element is multiplied by `alpha` when its real part is strictly negative,
    or when its real part is zero and its imaginary part is strictly negative.
    All other elements are returned unchanged.

    Parameters
    ----------
    x
        Input array.
    fn_original
        The originally decorated function. Accepted for signature
        compatibility; not used in this implementation.
    alpha
        Scaling factor applied to the selected elements.
    out
        Accepted for signature compatibility; not used in this implementation.

    Returns
    -------
    ret
        Array in which the selected elements are `x * alpha` (cast back to
        `x.dtype`) and the remaining elements are taken from `x`.
    """
    real_part = ivy.real(x)
    negative_real = real_part < 0
    # Tie-break on the imaginary axis: zero real part, negative imaginary part.
    negative_imag_on_axis = ivy.logical_and(real_part == 0, ivy.imag(x) < 0)
    scale_mask = ivy.logical_or(negative_real, negative_imag_on_axis)
    # Cast so that multiplying by `alpha` cannot change the result dtype.
    scaled = ivy.astype(x * alpha, x.dtype)
    return ivy.where(scale_mask, scaled, x)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def leaky_relu(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
alpha: float = 0.2,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the leaky rectified linear unit function element-wise.
If the input is complex, then by default each element is scaled by `alpha` if
either its real part is strictly negative or if its real part is zero and its
imaginary part is negative. This behaviour can be changed by specifying a different
`complex_mode`.
Parameters
----------
x
Expand All @@ -103,6 +130,9 @@ def leaky_relu(
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
`ivy.func_wrapper.handle_complex_input` for more detail.
Returns
-------
Expand Down Expand Up @@ -144,6 +174,9 @@ def leaky_relu(
return current_backend(x).leaky_relu(x, alpha=alpha, out=out)


leaky_relu.jax_like = _leaky_relu_jax_like


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
Expand Down
Loading

0 comments on commit bd03368

Please sign in to comment.