Skip to content

Commit

Permalink
migrate to python 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed May 14, 2023
1 parent db38cfd commit bd25b93
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![contact](https://img.shields.io/badge/[email protected])](mailto:[email protected])
[![License](https://img.shields.io/github/license/ain-soph/trojanzoo)](https://opensource.org/licenses/GPL-3.0)

![python>=3.10](https://img.shields.io/badge/python->=3.10-informational.svg)
![python>=3.11](https://img.shields.io/badge/python->=3.11-informational.svg)
[![docs](https://github.com/ain-soph/trojanzoo/workflows/docs/badge.svg)](https://ain-soph.github.io/trojanzoo/)

[![release](https://img.shields.io/github/v/release/ain-soph/trojanzoo)](https://github.com/ain-soph/trojanzoo/releases)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ install_requires =
pyyaml>=5.3.1
pandas>=1.1.5
tqdm>=4.54.1
python_requires = >=3.10
python_requires = >=3.11

[options.package_data]
* = *.yml
Expand Down
3 changes: 2 additions & 1 deletion trojanvision/models/normal/bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import os
from collections import OrderedDict
from typing import Self


class _BiT(_ImageModel):
Expand Down Expand Up @@ -164,7 +165,7 @@ def get_official_weights(self, **kwargs) -> OrderedDict[str, torch.Tensor]:
_dict['classifier.fc.bias'] = tf2th(weights['resnet/head/conv2d/bias'])
return _dict

def parametrize_(self, parametrize: bool = True):
def parametrize_(self, parametrize: bool = True) -> Self:
for mod in self.modules():
if isinstance(mod, StdConv2d):
mod.parametrize_(parametrize)
Expand Down
4 changes: 3 additions & 1 deletion trojanvision/utils/model_archs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch.nn as nn
import torch.nn.utils.parametrize as P

from typing import Self


class Std(nn.Module):
def forward(self, X: torch.Tensor):
Expand All @@ -18,7 +20,7 @@ def __init__(self, *args, parametrize: bool = True, **kwargs) -> None:
if parametrize:
P.register_parametrization(self, 'weight', Std())

def parametrize_(self, parametrize: bool = True):
def parametrize_(self, parametrize: bool = True) -> Self:
if parametrize:
if not self.parametrize:
P.register_parametrization(self, 'weight', Std())
Expand Down
14 changes: 7 additions & 7 deletions trojanzoo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from typing import TYPE_CHECKING
# TODO: python 3.10
from typing import Generator, Iterator, Mapping
from trojanzoo.configs import Config # TODO: python 3.10
from typing import Generator, Iterator, Mapping, Self
from trojanzoo.configs import Config
from trojanzoo.utils.model import ExponentialMovingAverage
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -1166,7 +1166,7 @@ def summary(self, depth: int = None, verbose: bool = True,

# -------------------------------Reload---------------------------- #

def train(self, mode: bool = True):
def train(self, mode: bool = True) -> Self:
r"""Sets the module in training mode.
See Also:
Expand All @@ -1176,7 +1176,7 @@ def train(self, mode: bool = True):
self.model.train(mode=mode)
return self

def eval(self):
def eval(self) -> Self:
r"""Sets the module in evaluation mode.
See Also:
Expand All @@ -1186,7 +1186,7 @@ def eval(self):
self.model.eval()
return self

def cpu(self):
def cpu(self) -> Self:
r"""Moves all model parameters and buffers to the CPU.
See Also:
Expand All @@ -1196,7 +1196,7 @@ def cpu(self):
self.model.cpu()
return self

def cuda(self, device: None | int | torch.device = None):
def cuda(self, device: None | int | torch.device = None) -> Self:
r"""Moves all model parameters and buffers to the GPU.
See Also:
Expand All @@ -1206,7 +1206,7 @@ def cuda(self, device: None | int | torch.device = None):
self.model.cuda(device=device)
return self

def zero_grad(self, set_to_none: bool = False):
def zero_grad(self, set_to_none: bool = False) -> Self:
r"""Sets gradients of all model parameters to zero.
See Also:
Expand Down
16 changes: 8 additions & 8 deletions trojanzoo/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm as tqdm_class


from typing import Generator, Iterable, TypeVar # TODO: python 3.10
from typing import Generator, Iterable, Self, TypeVar
_T = TypeVar("_T")

__all__ = ['SmoothedValue', 'MetricLogger', 'AverageMeter']
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(self, name: str = '', window_size: int = None, fmt: str = '{global_
self.total: float = 0.0
self.fmt = fmt

def update(self, value: float, n: int = 1) -> 'SmoothedValue':
def update(self, value: float, n: int = 1) -> Self:
r"""Update :attr:`n` pieces of data with same :attr:`value`.
.. code-block:: python
Expand All @@ -75,7 +75,7 @@ def update(self, value: float, n: int = 1) -> 'SmoothedValue':
self.count += n
return self

def update_list(self, value_list: list[float]) -> 'SmoothedValue':
def update_list(self, value_list: list[float]) -> Self:
r"""Update :attr:`value_list`.
.. code-block:: python
Expand All @@ -97,7 +97,7 @@ def update_list(self, value_list: list[float]) -> 'SmoothedValue':
self.count += len(value_list)
return self

def reset(self) -> 'SmoothedValue':
def reset(self) -> Self:
r"""Reset ``deque``, ``count`` and ``total`` to be empty.
Returns:
Expand Down Expand Up @@ -221,7 +221,7 @@ def __init__(self, delimiter: str = '',
self.data_time = SmoothedValue()
self.memory = SmoothedValue(fmt='{max:.0f}')

def create_meters(self, **kwargs: str) -> 'MetricLogger':
def create_meters(self, **kwargs: str) -> Self:
r"""Create meters with specific ``fmt`` in :attr:`self.meters`.
``self.meters[meter_name] = SmoothedValue(fmt=fmt)``
Expand All @@ -236,7 +236,7 @@ def create_meters(self, **kwargs: str) -> 'MetricLogger':
self.meters[k] = SmoothedValue(fmt='{global_avg:.3f}' if v is None else v)
return self

def update(self, n: int = 1, **kwargs: float) -> 'MetricLogger':
def update(self, n: int = 1, **kwargs: float) -> Self:
r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.
``self.meters[meter_name].update(float(value), n=n)``
Expand All @@ -252,7 +252,7 @@ def update(self, n: int = 1, **kwargs: float) -> 'MetricLogger':
self.meters[k].update(float(v), n=n)
return self

def update_list(self, **kwargs: list) -> 'MetricLogger':
def update_list(self, **kwargs: list) -> Self:
r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.
``self.meters[meter_name].update_list(value_list)``
Expand All @@ -267,7 +267,7 @@ def update_list(self, **kwargs: list) -> 'MetricLogger':
self.meters[k].update_list(v)
return self

def reset(self) -> 'MetricLogger':
def reset(self) -> Self:
r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.
Returns:
Expand Down
18 changes: 9 additions & 9 deletions trojanzoo/utils/module/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from trojanzoo.utils.output import prints

from typing import Generic, MutableMapping, TypeVar
from typing import Generic, MutableMapping, Self, TypeVar
_KT = TypeVar("_KT") # Key type.
_VT = TypeVar("_VT") # Value type.

Expand Down Expand Up @@ -30,7 +30,7 @@ def __init__(self, *args: MutableMapping[_KT, _VT], **kwargs: _VT):
return
self.update(*args, **kwargs)

def update(self, *args: MutableMapping[_KT, _VT], **kwargs: _VT):
def update(self, *args: MutableMapping[_KT, _VT], **kwargs: _VT) -> Self:
r"""update values.
Args:
Expand All @@ -49,7 +49,7 @@ def update(self, *args: MutableMapping[_KT, _VT], **kwargs: _VT):
return self

# TODO: issue 4 dict | Module
def _update(self, module: MutableMapping[_KT, _VT]):
def _update(self, module: MutableMapping[_KT, _VT]) -> Self:
for key, value in module.items():
if value is None:
continue
Expand All @@ -64,7 +64,7 @@ def _update(self, module: MutableMapping[_KT, _VT]):
self[key] = value
return self

def remove_none(self):
def remove_none(self) -> Self:
r"""Remove the parameters whose values are ``None``.
Returns:
Expand All @@ -83,7 +83,7 @@ def copy(self):
"""
return type(self)(self)

def clear(self):
def clear(self) -> Self:
r"""Remove all keys.
Returns:
Expand Down Expand Up @@ -164,21 +164,21 @@ class Param(Module, Generic[_KT, _VT]):
"""
_marker = 'P'

def update(self, *args: dict[_KT, _VT], **kwargs: _VT):
def update(self, *args: dict[_KT, _VT], **kwargs: _VT) -> Self:
if len(kwargs) == 0 and len(args) == 1 and \
not isinstance(args[0], (dict, Module)):
self.default = args[0]
return self
return super().update(*args, **kwargs)

def _update(self, module: dict[_KT, _VT]):
def _update(self, module: dict[_KT, _VT]) -> Self:
for key, value in module.items():
if key == 'default':
self.default = value
super()._update(module)
return self # For linting purpose

def remove_none(self):
def remove_none(self) -> Self:
for key in list(self.__data.keys()):
if self.__data[key] is None and \
not (isinstance(key, str) and key == 'default'):
Expand All @@ -201,7 +201,7 @@ def __getitem__(self, key: str) -> _VT:
raise KeyError(key)
return super().__getitem__(key)

def clear(self):
def clear(self) -> Self:
super().clear()
self.default = None
return self
2 changes: 1 addition & 1 deletion trojanzoo/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env python3

__version__ = '2.0.0'
__version__ = '2.0.1'

0 comments on commit bd25b93

Please sign in to comment.