Skip to content

Commit

Permalink
Refactor histogram op
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Oct 2, 2024
1 parent 08910e2 commit 1071028
Show file tree
Hide file tree
Showing 18 changed files with 259 additions and 241 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
RandomShear,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down Expand Up @@ -160,6 +159,7 @@
from keras.src.ops.numpy import get_item
from keras.src.ops.numpy import greater
from keras.src.ops.numpy import greater_equal
from keras.src.ops.numpy import histogram
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from keras.src.ops.numpy import get_item
from keras.src.ops.numpy import greater
from keras.src.ops.numpy import greater_equal
from keras.src.ops.numpy import histogram
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
RandomShear,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
2 changes: 1 addition & 1 deletion keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from keras.src.ops.math import extract_sequences
from keras.src.ops.math import fft
from keras.src.ops.math import fft2
from keras.src.ops.math import histogram
from keras.src.ops.math import in_top_k
from keras.src.ops.math import irfft
from keras.src.ops.math import istft
Expand Down Expand Up @@ -160,6 +159,7 @@
from keras.src.ops.numpy import get_item
from keras.src.ops.numpy import greater
from keras.src.ops.numpy import greater_equal
from keras.src.ops.numpy import histogram
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from keras.src.ops.numpy import get_item
from keras.src.ops.numpy import greater
from keras.src.ops.numpy import greater_equal
from keras.src.ops.numpy import histogram
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
Expand Down
4 changes: 0 additions & 4 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,3 @@ def logdet(x):
# `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `jnp.histogram`.

    Args:
        x: Input data forwarded to `jnp.histogram`.
        bins: Bin specification forwarded as the `bins` keyword.
        range: Bin range forwarded as the `range` keyword (may be `None`).

    Returns:
        The `(counts, bin_edges)` pair produced by `jnp.histogram`.
    """
    counts, bin_edges = jnp.histogram(x, bins=bins, range=range)
    return counts, bin_edges
4 changes: 4 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,3 +1246,7 @@ def slogdet(x):

def argpartition(x, kth, axis=-1):
    """Return partition indices of `x` via `jnp.argpartition`.

    Args:
        x: Input array.
        kth: Index of the element to partition around.
        axis: Axis along which to partition. Defaults to `-1`.

    Returns:
        The index array produced by `jnp.argpartition(x, kth, axis)`.
    """
    partition_indices = jnp.argpartition(x, kth, axis)
    return partition_indices


def histogram(x, bins, range):
    """Histogram of `x`, computed by `jnp.histogram`.

    Args:
        x: Input data.
        bins: Passed through as `bins=`.
        range: Passed through as `range=`; may be `None`.

    Returns:
        A `(counts, bin_edges)` tuple as returned by `jnp.histogram`.
    """
    hist, edges = jnp.histogram(x, bins=bins, range=range)
    return hist, edges
4 changes: 0 additions & 4 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,3 @@ def logdet(x):
# In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See
# https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
return slogdet(x)[1]


def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `np.histogram`.

    Args:
        x: Input data forwarded to `np.histogram`.
        bins: Bin specification forwarded as the `bins` keyword.
        range: Bin range forwarded as the `range` keyword (may be `None`).

    Returns:
        The `(counts, bin_edges)` pair produced by `np.histogram`.
    """
    counts, bin_edges = np.histogram(x, bins=bins, range=range)
    return counts, bin_edges
4 changes: 4 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,3 +1177,7 @@ def slogdet(x):

def argpartition(x, kth, axis=-1):
    """Return partition indices of `x` as an `int32` array.

    Args:
        x: Input array.
        kth: Index of the element to partition around.
        axis: Axis along which to partition. Defaults to `-1`.

    Returns:
        The result of `np.argpartition(x, kth, axis)` cast to `"int32"`.
    """
    partition_indices = np.argpartition(x, kth, axis)
    return partition_indices.astype("int32")


def histogram(x, bins, range):
    """Histogram of `x`, computed by `np.histogram`.

    Args:
        x: Input data.
        bins: Passed through as `bins=`.
        range: Passed through as `range=`; may be `None`.

    Returns:
        A `(counts, bin_edges)` tuple as returned by `np.histogram`.
    """
    hist, edges = np.histogram(x, bins=bins, range=range)
    return hist, edges
28 changes: 0 additions & 28 deletions keras/src/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,31 +370,3 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
    """Return `tf.linalg.logdet` of `x` after converting it to a tensor."""
    tensor = convert_to_tensor(x)
    return tf.linalg.logdet(tensor)


def histogram(x, bins, range):
    """Computes a histogram of the data tensor `x` using TensorFlow.

    The `tf.histogram_fixed_width()` and `tf.histogram_fixed_width_bins()`
    methods yielded slight numerical differences on some edge cases, so
    the histogram is assembled from `Bucketize` and `bincount` instead.

    Args:
        x: Input data; must expose a `dtype` attribute.
        bins: Number of bins.
        range: `(min, max)` pair, or `None` to use the min and max of `x`.

    Returns:
        A `(bin_counts, bin_edges)` tuple. `bin_counts` has the dtype of
        `x`; `bin_edges` holds `bins + 1` evenly spaced values.
    """
    x = tf.convert_to_tensor(x, dtype=x.dtype)

    # Resolve the histogram bounds, falling back to the data extent.
    if range is not None:
        min_val, max_val = range
    else:
        min_val = tf.reduce_min(x)
        max_val = tf.reduce_max(x)

    # Discard values that fall outside [min_val, max_val].
    in_range = (x >= min_val) & (x <= max_val)
    x = tf.boolean_mask(x, in_range)

    bin_edges = tf.linspace(min_val, max_val, bins + 1)
    # `Bucketize` takes its boundaries as a Python list; only the interior
    # edges are passed.
    # NOTE(review): `.numpy()` presumably restricts this to eager mode.
    interior_edges = bin_edges.numpy().tolist()[1:-1]
    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=interior_edges)

    bin_counts = tf.math.bincount(
        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
    )

    return bin_counts, bin_edges
28 changes: 28 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,3 +2546,31 @@ def argpartition(x, kth, axis=-1):

out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1)
return swapaxes(out, -1, axis)


def histogram(x, bins, range):
    """Computes a histogram of the data tensor `x`.

    Note: the `tf.histogram_fixed_width()` and
    `tf.histogram_fixed_width_bins()` functions
    yield slight numerical differences for some edge cases.

    Args:
        x: Input data; must expose a `dtype` attribute.
        bins: Number of bins.
        range: `(min, max)` pair, or `None` to use the min and max of `x`.

    Returns:
        A `(bin_counts, bin_edges)` tuple. `bin_counts` has the dtype of
        `x`; `bin_edges` holds `bins + 1` evenly spaced values.
    """
    x = tf.convert_to_tensor(x, dtype=x.dtype)

    # Resolve the histogram bounds, falling back to the data extent.
    if range is not None:
        min_val, max_val = range
    else:
        min_val = tf.reduce_min(x)
        max_val = tf.reduce_max(x)

    # Keep only the values inside the closed interval [min_val, max_val].
    keep_mask = (x >= min_val) & (x <= max_val)
    x = tf.boolean_mask(x, keep_mask)

    bin_edges = tf.linspace(min_val, max_val, bins + 1)
    # Only the interior edges are handed to `Bucketize`, as a Python list.
    # NOTE(review): `.numpy()` presumably restricts this to eager mode.
    inner_boundaries = bin_edges.numpy().tolist()[1:-1]
    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=inner_boundaries)

    bin_counts = tf.math.bincount(
        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
    )
    return bin_counts, bin_edges
5 changes: 0 additions & 5 deletions keras/src/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,3 @@ def norm(x, ord=None, axis=None, keepdims=False):
def logdet(x):
    """Return `torch.logdet` of `x` after converting it to a tensor."""
    tensor = convert_to_tensor(x)
    return torch.logdet(tensor)


def histogram(x, bins, range):
    """Compute a histogram of `x` by delegating to `torch.histogram`.

    Args:
        x: Input tensor forwarded to `torch.histogram`.
        bins: Bin specification forwarded as the `bins` keyword.
        range: Bin range forwarded as the `range` keyword (may be `None`).

    Returns:
        A `(hist, bin_edges)` tuple taken from the `torch.histogram` result.
    """
    result = torch.histogram(x, bins=bins, range=range)
    counts = result.hist
    edges = result.bin_edges
    return counts, edges
5 changes: 5 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,3 +1701,8 @@ def set_to_zero(a, i):
top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1]
out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1)
return cast(torch.transpose(out, -1, axis), "int32")


def histogram(x, bins, range):
    """Histogram of `x`, computed by `torch.histogram`.

    Args:
        x: Input tensor.
        bins: Passed through as `bins=`.
        range: Passed through as `range=`; may be `None`.

    Returns:
        A plain `(hist, bin_edges)` tuple built from the fields of the
        `torch.histogram` result.
    """
    outcome = torch.histogram(x, bins=bins, range=range)
    return (outcome.hist, outcome.bin_edges)
86 changes: 0 additions & 86 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,89 +971,3 @@ def logdet(x):
if any_symbolic_tensors((x,)):
return Logdet().symbolic_call(x)
return backend.math.logdet(x)


class Histogram(Operation):
    """Operation computing a histogram of a 1-dimensional tensor.

    Args:
        bins: Number of histogram bins. Must be a non-negative `int`.
            Defaults to `10`.
        range: Optional `(lower, upper)` tuple bounding the bins, or
            `None` to let the backend choose the bounds.

    Raises:
        TypeError: If `bins` is not an `int`.
        ValueError: If `bins` is negative, if `range` is not a tuple of
            exactly two elements, or if `range[1] < range[0]`.
    """

    def __init__(self, bins=10, range=None):
        super().__init__()

        if not isinstance(bins, int):
            raise TypeError("bins must be of type `int`")
        if bins < 0:
            raise ValueError("`bins` should be a non-negative integer")

        # Compare against `None` rather than truthiness so that falsy but
        # invalid values (e.g. an empty tuple) are still validated.
        if range is not None:
            # Check the type first so `len()` is never called on an
            # unsized object, and require exactly two elements so the
            # check matches the error message.
            if not isinstance(range, tuple) or len(range) != 2:
                raise ValueError("range must be a tuple of two elements")

            if range[1] < range[0]:
                raise ValueError(
                    "The second element of range must be greater than the first"
                )

        self.bins = bins
        self.range = range

    def call(self, x):
        """Compute the histogram of `x` with the active backend.

        Raises:
            ValueError: If `x` has more than one dimension.
        """
        x = backend.convert_to_tensor(x)
        if len(x.shape) > 1:
            raise ValueError("Input tensor must be 1-dimensional")
        return backend.math.histogram(x, bins=self.bins, range=self.range)

    def compute_output_spec(self, x):
        """Return symbolic specs for the `(counts, bin_edges)` outputs."""
        return (
            KerasTensor(shape=(self.bins,), dtype=x.dtype),
            KerasTensor(shape=(self.bins + 1,), dtype=x.dtype),
        )


@keras_export("keras.ops.histogram")
def histogram(x, bins=10, range=None):
    """Computes a histogram of the data tensor `x`.

    Args:
        x: Input tensor. Must be 1-dimensional.
        bins: An integer representing the number of histogram bins.
            Defaults to 10.
        range: A tuple representing the lower and upper range of the bins.
            If not specified, it will use the min and max of `x`.

    Returns:
        A tuple containing:
        - A tensor representing the counts of elements in each bin.
        - A tensor representing the bin edges.

    Raises:
        TypeError: If `bins` is not an `int`.
        ValueError: If `bins` is negative, if `range` is not a tuple of
            exactly two elements, if `range[1] < range[0]`, or if `x`
            has more than one dimension.

    Example:

    ```
    >>> input_tensor = np.random.rand(8)
    >>> keras.ops.histogram(input_tensor)
    (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32),
    array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262,
    0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101,
    0.85892869]))
    ```
    """

    if not isinstance(bins, int):
        raise TypeError("bins must be of type `int`")
    if bins < 0:
        raise ValueError("`bins` should be a non-negative integer")

    # Compare against `None` rather than truthiness so that falsy but
    # invalid values (e.g. an empty tuple) are still validated.
    if range is not None:
        # Check the type first so `len()` is never called on an unsized
        # object, and require exactly two elements so the check matches
        # the error message.
        if not isinstance(range, tuple) or len(range) != 2:
            raise ValueError("range must be a tuple of two elements")

        if range[1] < range[0]:
            raise ValueError(
                "The second element of range must be greater than the first"
            )

    if any_symbolic_tensors((x,)):
        return Histogram(bins=bins, range=range).symbolic_call(x)

    x = backend.convert_to_tensor(x)
    if len(x.shape) > 1:
        raise ValueError("Input tensor must be 1-dimensional")
    return backend.math.histogram(x, bins=bins, range=range)
112 changes: 0 additions & 112 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,115 +1468,3 @@ def test_istft_invalid_window_shape_2D_inputs(self):
fft_length,
window=incorrect_window,
)


class HistogramTest(testing.TestCase):
    """Tests for `kmath.histogram`, using `np.histogram` as the reference."""

    def test_histogram_default_args(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(8)

        # Expected output
        expected_counts, expected_edges = np.histogram(input_tensor)

        counts, edges = hist_op(input_tensor)

        self.assertEqual(counts.shape, expected_counts.shape)
        self.assertAllClose(counts, expected_counts)
        self.assertEqual(edges.shape, expected_edges.shape)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_custom_bins(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(8)
        bins = 5

        # Expected output
        expected_counts, expected_edges = np.histogram(input_tensor, bins=bins)

        counts, edges = hist_op(input_tensor, bins=bins)

        self.assertEqual(counts.shape, expected_counts.shape)
        self.assertAllClose(counts, expected_counts)
        self.assertEqual(edges.shape, expected_edges.shape)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_custom_range(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(10)
        range_specified = (2, 8)

        # Expected output
        expected_counts, expected_edges = np.histogram(
            input_tensor, range=range_specified
        )

        counts, edges = hist_op(input_tensor, range=range_specified)

        self.assertEqual(counts.shape, expected_counts.shape)
        self.assertAllClose(counts, expected_counts)
        self.assertEqual(edges.shape, expected_edges.shape)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_symbolic_input(self):
        hist_op = kmath.histogram
        input_tensor = KerasTensor(shape=(None,), dtype="float32")

        counts, edges = hist_op(input_tensor)

        self.assertEqual(counts.shape, (10,))
        self.assertEqual(edges.shape, (11,))

    def test_histogram_non_integer_bins_raises_error(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(8)

        # A non-integer `bins` is rejected with a `TypeError`.
        with self.assertRaisesRegex(TypeError, "bins must be of type `int`"):
            hist_op(input_tensor, bins=2.5)

        # A negative integer `bins` is rejected with a `ValueError`.
        with self.assertRaisesRegex(
            ValueError, "`bins` should be a non-negative integer"
        ):
            hist_op(input_tensor, bins=-5)

    def test_histogram_range_validation(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(8)

        with self.assertRaisesRegex(
            ValueError, "range must be a tuple of two elements"
        ):
            hist_op(input_tensor, range=(1,))

        with self.assertRaisesRegex(
            ValueError,
            "The second element of range must be greater than the first",
        ):
            hist_op(input_tensor, range=(5, 1))

    def test_histogram_large_values(self):
        hist_op = kmath.histogram
        input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10])

        counts, edges = hist_op(input_tensor, bins=5)

        expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

        self.assertAllClose(counts, expected_counts)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_float_input(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(8)

        counts, edges = hist_op(input_tensor, bins=5)

        expected_counts, expected_edges = np.histogram(input_tensor, bins=5)

        self.assertAllClose(counts, expected_counts)
        self.assertAllClose(edges, expected_edges)

    def test_histogram_high_dimensional_input(self):
        hist_op = kmath.histogram
        input_tensor = np.random.rand(3, 4, 5)

        with self.assertRaisesRegex(
            ValueError, "Input tensor must be 1-dimensional"
        ):
            hist_op(input_tensor)
Loading

0 comments on commit 1071028

Please sign in to comment.