Skip to content

Commit

Permalink
Apply black.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Feb 13, 2024
1 parent 25f53e3 commit 8c9d1f3
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
6 changes: 2 additions & 4 deletions candle-pyo3/py_src/candle/nn/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,10 @@ class Sequential(Module):
_modules: Dict[str, Module] # type: ignore[assignment]

@overload
def __init__(self, *args: Module) -> None:
...
def __init__(self, *args: Module) -> None: ...

@overload
def __init__(self, arg: "OrderedDict[str, Module]") -> None:
...
def __init__(self, arg: "OrderedDict[str, Module]") -> None: ...

def __init__(self, *args):
super().__init__()
Expand Down
12 changes: 4 additions & 8 deletions candle-pyo3/py_src/candle/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,10 @@ def named_buffers(
T_destination = TypeVar("T_destination", bound=Dict[str, Any])

@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
...
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...

@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...

def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module.
Expand Down Expand Up @@ -586,12 +584,10 @@ def to(
self: T,
device: str = ...,
dtype: Optional[Union[DType, str]] = ...,
) -> T:
...
) -> T: ...

@overload
def to(self: T, dtype: Union[DType, str]) -> T:
...
def to(self: T, dtype: Union[DType, str]) -> T: ...

def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
Expand Down
1 change: 1 addition & 0 deletions candle-pyo3/py_src/candle/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class LayerNorm(Module):
math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
"""

__constants__ = ["normalized_shape", "eps"]
normalized_shape: Tuple[int, ...]
eps: float
Expand Down

0 comments on commit 8c9d1f3

Please sign in to comment.