forked from AlexCheema/mlx-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
distributions.py
31 lines (24 loc) · 850 Bytes
/
distributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Optional, Tuple, Union
import mlx.core as mx
class Normal:
    """Element-wise normal distribution with mean ``mu`` and std ``sigma``.

    Sampling uses the reparameterization ``mu + sigma * eps`` with
    ``eps ~ N(0, 1)``. ``log_prob`` returns the element-wise log density
    (it is not summed over any axis).

    Args:
        mu: Mean(s) of the distribution.
        sigma: Standard deviation(s). These must broadcast with ``mu``.
            They are presumably expected to be strictly positive, since
            ``log_prob`` takes ``log(sigma)``.
    """

    def __init__(self, mu: mx.array, sigma: mx.array):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def sample(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        """Draw reparameterized samples ``eps * sigma + mu``.

        Args:
            sample_shape: Shape of the standard-normal noise ``eps``. A bare
                ``int`` is treated as the 1-D shape ``(sample_shape,)``. The
                noise is broadcast against ``mu``/``sigma`` rather than
                prepended to their shape. The result shape is therefore the
                broadcast of ``sample_shape`` with the parameter shapes.
            key: Optional PRNG key forwarded to ``mx.random.normal``.

        Returns:
            An array of samples.
        """
        if isinstance(sample_shape, int):
            # mx.random.normal takes its shape as a sequence of ints, so wrap
            # a scalar to honour the ``int`` option advertised in the annotation.
            sample_shape = (sample_shape,)
        eps = mx.random.normal(sample_shape, key=key)
        return eps * self.sigma + self.mu

    def log_prob(self, x: mx.array):
        """Return the element-wise log density of ``x``.

        Computes ``-0.5*log(2*pi) - log(sigma) - 0.5*((x - mu)/sigma)**2``.
        The result has the broadcast shape of ``x``, ``mu`` and ``sigma``.
        """
        return (
            -0.5 * math.log(2 * math.pi)
            - mx.log(self.sigma)
            - 0.5 * ((x - self.mu) / self.sigma) ** 2
        )

    def sample_and_log_prob(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        """Draw samples and return them with their element-wise log density.

        Args:
            sample_shape: Passed through to :meth:`sample`.
            key: Optional PRNG key passed through to :meth:`sample`.

        Returns:
            A tuple ``(x, log_prob(x))``.
        """
        x = self.sample(sample_shape, key=key)
        return x, self.log_prob(x)