Skip to content

Commit

Permalink
Updated some unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
weidler committed Jul 25, 2022
1 parent 4bd3f07 commit ab7a542
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions angorapy/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
import math
from typing import Union, Tuple
from typing import Union, Tuple, List

import gym
import numpy as np
Expand Down Expand Up @@ -341,7 +341,7 @@ def _entropy_from_params(self, stdevs: tf.Tensor):
Input Shape: (B, A) or (B, S, A) for recurrent
Since the given r.v.'serialization are independent, the subadditivity property of entropy narrows down to an equality
Since the given r.v.'s are independent, the sub-additivity property of entropy narrows down to an equality
of the joint entropy and the sum of marginal entropies.
"""
entropy = .5 * tf.math.log(2 * math.pi * math.e * tf.pow(stdevs, 2))
Expand Down Expand Up @@ -373,7 +373,7 @@ def _approx_entropy_from_log(self, log_stdevs: tf.Tensor):
return tf.reduce_sum(log_stdevs, axis=-1)

@tf.function
def entropy(self, params: tf.Tensor):
def entropy(self, params: Tuple[tf.Tensor, tf.Tensor]):
"""Calculate the joint entropy of Gaussian random variables described by their log standard deviations.
Input Shape: (B, A) or (B, S, A) for recurrent
Expand Down
2 changes: 1 addition & 1 deletion tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_gaussian_entropy(self):
sig = tf.convert_to_tensor([[1.0, 1.0], [1.0, 5.0]], dtype=tf.float32)

result_reference = np.sum(norm.entropy(loc=mu, scale=sig), axis=-1)
result_log = distro.entropy(np.log(sig)).numpy()
result_log = distro.entropy([mu, np.log(sig)]).numpy()
result = distro._entropy_from_params(sig).numpy()

self.assertTrue(np.allclose(result_reference, result), msg="Gaussian entropy returns wrong result")
Expand Down

0 comments on commit ab7a542

Please sign in to comment.