diff --git a/.github/workflows/python-tox.yml b/.github/workflows/python-tox.yml new file mode 100644 index 0000000..fe8dca2 --- /dev/null +++ b/.github/workflows/python-tox.yml @@ -0,0 +1,52 @@ +name: tox + +on: [ + push, + pull_request, +] + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ + ubuntu-latest, + macos-12, + windows-latest, + ] + + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" + - name: Install dependencies + run: pip install -r requirements.txt + - name: Install package + run: pip install -e . + + tox: + + runs-on: ubuntu-latest + strategy: + matrix: + python: [ + "3.10", + "3.12", + ] + + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + - name: Install tox + run: pip install tox + - name: Run tox + run: tox diff --git a/CHANGELOG.md b/CHANGELOG.md index 84bba8f..5d3c705 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,17 @@ # Changelog -## 2027/08/09 - 00 - v0.0.7 - Update Callback +## 2024/10/14 - 00 - v0.0.8 - Instantiation and GitHub CI +* Instantiated backends and solvers with different numerical libraries: + * Added `context_manager` modules to `quantrl.backends` and `quantrl.solvers` packages. + * Added support for JIT-compilation in `quantrl.backends` package. + * Changed naming convention for GPU-based options as `gpu` from `cuda`. +* Added `tox`-based testing and continuous integration with GitHub workflows. +* Updated all modules with PEP-based code styling and added `pylintrc`. +* Updated `requirements.txt` and added version to `quantrl.__init__`. +* Replaced `setup.py` with `pyproject.toml`. +* Updated `docs` and `README`. + +## 2024/08/09 - 00 - v0.0.7 - Update Callback * Fixed issue with cache indexing in `quantrl.envs.base` module. * Updated best reward callback in `quantrl.utils` module. 
* Added `seaborn` to `requirements` and `setup`. diff --git a/README.md b/README.md index 89f9ab7..bdf2bd7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # QuantRL: Quantum Control using Reinforcement Learning -![Latest Version](https://img.shields.io/badge/version-0.0.7-red?style=for-the-badge) +![Latest Version](https://img.shields.io/badge/version-0.0.8-red?style=for-the-badge) > A library of modules to interface deterministic and stochastic quantum models for reinforcement learning. diff --git a/docs/source/quantrl.backends.rst b/docs/source/quantrl.backends.rst index 9f24f3c..71ac2c0 100644 --- a/docs/source/quantrl.backends.rst +++ b/docs/source/quantrl.backends.rst @@ -10,6 +10,15 @@ quantrl.backends.base module :undoc-members: :show-inheritance: +quantrl.backends.context_manager module +--------------------------------------- + +.. automodule:: quantrl.backends.context_manager + :members: + :private-members: + :undoc-members: + :show-inheritance: + quantrl.backends.jax module --------------------------- diff --git a/docs/source/quantrl.solvers.rst b/docs/source/quantrl.solvers.rst index 50fa967..c945439 100644 --- a/docs/source/quantrl.solvers.rst +++ b/docs/source/quantrl.solvers.rst @@ -10,6 +10,15 @@ quantrl.solvers.base module :undoc-members: :show-inheritance: +quantrl.solvers.context_manager module +-------------------------------------- + +.. 
automodule:: quantrl.solvers.context_manager + :members: + :private-members: + :undoc-members: + :show-inheritance: + quantrl.solvers.jax module -------------------------- diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..fd636f5 --- /dev/null +++ b/pylintrc @@ -0,0 +1,26 @@ +[MESSAGES CONTROL] +disable=too-many-lines, + too-many-locals, + too-many-arguments, + too-many-statements, + too-few-public-methods, + too-many-public-methods, + too-many-instance-attributes, + too-many-positional-arguments, + line-too-long, # no bounded lines + redefined-builtin, # catches __name__ + duplicate-code, # catches __init__ + unused-argument, # catches Gym.Env methods + import-outside-toplevel, # for numerical libraries + not-callable, # catches JAX JIT functions + fixme + +[BASIC] +attr-naming-style=any +variable-naming-style=any +argument-naming-style=any +const-naming-style=UPPER_CASE +method-naming-style=any +function-naming-style=any +class-naming-style=PascalCase +class-attribute-naming-style=snake_case diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7c799c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,87 @@ +[build-system] +requires = ["cython", "setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +dynamic = ["version"] +name = "quantrl" +authors = [ + {name = "Sampreet Kalita", email = "9553215+Sampreet@users.noreply.github.com"}, +] +maintainers = [ + {name = "Sampreet Kalita", email = "9553215+Sampreet@users.noreply.github.com"}, +] +description = "Quantum Control with Reinforcement Learning" +keywords = [ + "quantum", + "toolbox", + "reinforcement learning", + "python3", +] +readme = "README.md" +license = {file = "LICENSE.txt"} +requires-python = ">=3.10" +dependencies = [ + "numpy<2.0.0", + "scipy", + "matplotlib", + "tqdm", + "pillow", + "pandas", + "gymnasium", + "stable-baselines3", +] +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 2 - Pre-Alpha", + 
"Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering", +] + +[project.optional-dependencies] +jax-cpu = [ + "jax", + "diffrax", +] +jax-gpu = [ + "jax[cuda12]", + "diffrax", +] +torch = [ + "torch", + "torchvision", + "torchaudio", + "torchdiffeq", +] + +[project.urls] +Homepage = "https://github.com/sampreet/quantrl" +Repository = "https://github.com/sampreet/quantrl" +Issues = "https://github.com/sampreet/quantrl/issues" +Changelog = "https://github.com/sampreet/quantrl/blob/master/CHANGELOG.md" + +[tool.setuptools.packages.find] +include = ["quantrl*"] +namespaces = false + +[tool.setuptools.dynamic] +version = {attr = "quantrl.__version__"} + +[tool.tox] +legacy_tox_ini = """ +[tox] +requires = + tox>=4 + virtualenv>=20 +env_list = + lint + +[testenv:lint] +description = run pylint under {base_python} +deps = + -r requirements_tox.txt +commands = + pylint quantrl +""" diff --git a/quantrl/__init__.py b/quantrl/__init__.py index e69de29..e83f115 100644 --- a/quantrl/__init__.py +++ b/quantrl/__init__.py @@ -0,0 +1,2 @@ +"""Module to initialize QuantRL.""" +__version__ = "0.0.8" diff --git a/quantrl/backends/base.py b/quantrl/backends/base.py index d8e4025..8fed94f 100644 --- a/quantrl/backends/base.py +++ b/quantrl/backends/base.py @@ -6,10 +6,11 @@ __name__ = 'quantrl.backends.base' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-10-09" +__updated__ = "2024-10-13" # dependencies from abc import ABC, abstractmethod + import numpy as np class BaseBackend(ABC): @@ -29,9 +30,9 @@ class BaseBackend(ABC): def __init__( self, - name, - library, - tensor_type, + name:str='numpy', + library:np=np, + tensor_type:np.ndarray=np.ndarray, precision:str='double' ): # validate parameters @@ -68,6 +69,7 @@ def __init__( } } } + self.seed_sequence = None def is_typed(self, tensor, @@ -91,11 +93,32 @@ def is_typed(self, _dtype = 
self.dtype_from_str( dtype=dtype ) - if type(tensor) == self.tensor_type: + if isinstance(tensor, self.tensor_type): if dtype is None or (dtype is not None and tensor.dtype == _dtype): return True return False + def get_seedsequence(self, + seed:int=None + ) -> np.random.SeedSequence: + """Method to obtain a SeedSequence object. + + Parameters + ---------- + seed: int + Initial seed to obtain the entropy. + + Returns + ------- + seed_sequence: :class:`numpy.random.SeedSequence` + The SeedSequence object. + """ + if seed is None: + entropy = np.random.randint(1234567890) + else: + entropy = np.random.default_rng(seed).integers(0, 1234567890, (1, ))[0] + return np.random.SeedSequence(entropy) + @abstractmethod def convert_to_typed(self, tensor, @@ -546,37 +569,41 @@ def dtype_from_str(self, dtype: type Selected data-type. """ - - # validate params - assert dtype is None or dtype in ['integer', 'real', 'complex'], "parameter ``dtype`` can be either ``'integer'``, ``'real'`` or ``'complex'``." 
- # default dtype is the real data-type - if dtype is None: + if dtype is None or dtype not in ['integer', 'real', 'complex']: dtype = 'real' return self.dtypes['numpy' if numpy else 'typed'][self.precision][dtype] - + def jit_transpose(self, tensor, axis_0, axis_1): + """Method to JIT-compile transposition.""" return self.transpose(tensor, axis_0, axis_1) - + def jit_repeat(self, tensor, repeats, axis): + """Method to JIT-compile repetition.""" return self.repeat(tensor, repeats, axis) - + def jit_add(self, tensor_0, tensor_1, out): + """Method to JIT-compile addition.""" return self.add(tensor_0, tensor_1, out=out) - + def jit_matmul(self, tensor_0, tensor_1, out): + """Method to JIT-compile matrix multiplication.""" return self.matmul(tensor_0, tensor_1, out) - + def jit_dot(self, tensor_0, tensor_1, out): + """Method to JIT-compile dot product.""" return self.dot(tensor_0, tensor_1, out) - + def jit_concatenate(self, tensors, axis, out): + """Method to JIT-compile concatenation.""" return self.concatenate(tensors, axis, out) - + def jit_stack(self, tensors, axis, out): + """Method to JIT-compile stacking.""" return self.stack(tensors, axis, out) - + def jit_update(self, tensor, indices, values): + """Method to JIT-compile updation.""" return self.update(tensor, indices, values) def empty(self, @@ -1003,4 +1030,4 @@ def argmax(self, return self.library.argmax(self.convert_to_typed( tensor=tensor - )) \ No newline at end of file + )) diff --git a/quantrl/backends/context_manager.py b/quantrl/backends/context_manager.py index 1aee460..b5c5be2 100644 --- a/quantrl/backends/context_manager.py +++ b/quantrl/backends/context_manager.py @@ -6,18 +6,35 @@ __name__ = 'quantrl.backends.context_manager' __authors__ = ["Sampreet Kalita"] __created__ = "2024-10-09" -__updated__ = "2024-10-09" +__updated__ = "2024-10-13" # quantrl modules from .base import BaseBackend -BACKEND_INSTANCES = dict() +BACKEND_INSTANCES = {} +# TODO: validate arguments def get_backend_instance( 
library:str, precision:str='double', - device:str='cuda' + device:str='gpu' ) -> BaseBackend: + """Method to obtain an instantiated backend. + + Parameters + ---------- + library: str + Name of the library. Options are ``'jax'``, ``'numpy'`` and ``'torch'``. + precision: str, default='double' + Precision of the numerical values in the backend. Options are ``'single'`` and ``'double'``. + device: str, default='gpu' + Device for the backend. Options are ``'cpu'`` and ``'gpu'``. + + Returns + ------- + backend: :class:`quantrl.backends.base.BaseBackend` + The instantiated backend. + """ if library in BACKEND_INSTANCES: return BACKEND_INSTANCES[library] if 'jax' in library.lower(): @@ -33,4 +50,4 @@ def get_backend_instance( from .numpy import NumPyBackend BACKEND_INSTANCES['numpy'] = NumPyBackend(precision=precision) library = 'numpy' - return BACKEND_INSTANCES[library] \ No newline at end of file + return BACKEND_INSTANCES[library] diff --git a/quantrl/backends/jax.py b/quantrl/backends/jax.py index ebef77f..55b8673 100644 --- a/quantrl/backends/jax.py +++ b/quantrl/backends/jax.py @@ -6,23 +6,30 @@ __name__ = 'quantrl.backends.jax' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-10-09" +__updated__ = "2024-10-13" # dependencies from inspect import getfullargspec -import numpy as np + import jax import jax.numpy as jnp +import numpy as np # quantrl modules from .base import BaseBackend # TODO: Implement buffers +# TODO: Implement equinox class JaxBackend(BaseBackend): + """Backend to interface the JAX library. + + Refer to :class:`quantrl.backends.base.BaseBackend` for further documentation. 
+ """ + def __init__(self, precision:str='double' - ): + ): # initialize BaseBackend super().__init__( name='jax', @@ -38,60 +45,59 @@ def __init__(self, # set key self.key = None - def numpy_transpose( - tensor, + def transpose( + tensor:jax.Array, axis_0:int=None, axis_1:int=None - ): - # get swapped axes - _shape = self.shape( - tensor=tensor - ) - _axes = np.arange(len(_shape)) - if axis_0 is not None and axis_1 is not None: - _axes[axis_1] = axis_0 % len(_shape) - _axes[axis_0] = axis_1 % len(_shape) - return np.transpose(tensor, axes=_axes) - else: + ) -> jax.Array: + if axis_0 is None or axis_1 is None: return self.convert_to_typed( tensor=tensor ).T + # get swapped axes + _shape = jnp.shape(tensor) + _axes = np.arange(len(_shape)) + _axes[axis_1] = axis_0 % len(_shape) + _axes[axis_0] = axis_1 % len(_shape) + + return jnp.transpose(tensor, axes=_axes) + self.jit_transpose = jax.jit( - fun=numpy_transpose, - static_argnames=('axis_0', 'axis_1') + fun=transpose, + static_argnums=(1, 2) ) self.jit_repeat = jax.jit( - fun=lambda tensor, repeats, axis: jnp.repeat(tensor, repeats, axis), - static_argnames=('repeats', 'axis') + fun=jnp.repeat, + static_argnums=(1, 2) ) self.jit_add = jax.jit( fun=lambda tensor_0, tensor_1, out: jnp.add(tensor_0, tensor_1), - donate_argnames='out' + donate_argnums=(2, ) ) self.jit_matmul = jax.jit( fun=lambda tensor_0, tensor_1, out: jnp.matmul(tensor_0, tensor_1), - donate_argnames='out' + donate_argnums=(2, ) ) self.jit_dot = jax.jit( fun=lambda tensor_0, tensor_1, out: jnp.dot(tensor_0, tensor_1), - donate_argnames='out' + donate_argnums=(2, ) ) self.jit_concatenate = jax.jit( fun=lambda tensors, axis, out: jnp.concatenate(tensors, axis), - static_argnames='axis', - donate_argnames='out' + static_argnums=(1, ), + donate_argnums=(2, ) ) self.jit_stack = jax.jit( fun=lambda tensors, axis, out: jnp.stack(tensors, axis), - static_argnames='axis', - donate_argnames='out' + static_argnums=(1, ), + donate_argnums=(2, ) ) 
self.jit_update = jax.jit( @@ -138,7 +144,7 @@ def integers(self, high:int=1000, dtype:str=None ) -> jax.Array: - return jax.random.randint(generator, shape, low, high, dtype=self.dtype_from_str( + return jnp.asarray(jax.random.randint(generator, shape, low, high), dtype=self.dtype_from_str( dtype=dtype )) @@ -246,14 +252,14 @@ def stack(self, axis=axis, out=out ) - + def update(self, tensor, indices, values ) -> jax.Array: return tensor.at[indices].set(values) - + def if_else(self, condition, func_true, @@ -279,4 +285,4 @@ def body_func(i, state): # loop and return typed tensor return jax.lax.fori_loop(0, iterations_i, body_func, (self.convert_to_typed( tensor=Y - ), args))[0] \ No newline at end of file + ), args))[0] diff --git a/quantrl/backends/numpy.py b/quantrl/backends/numpy.py index a8e85a2..18449f3 100644 --- a/quantrl/backends/numpy.py +++ b/quantrl/backends/numpy.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.backends.numpy' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-10-09" +__updated__ = "2024-10-13" # dependencies import numpy as np @@ -15,6 +15,13 @@ from .base import BaseBackend class NumPyBackend(BaseBackend): + """Backend to interface the NumPy library. + + Parameters + ---------- + precision: str, default='double' + Precision of the numerical values in the backend. Options are ``'single'`` and ``'double'``. 
+ """ def __init__(self, precision:str='double' ): @@ -25,20 +32,14 @@ def __init__(self, tensor_type=np.ndarray, precision=precision ) - # set seeder - self.seeder = None def convert_to_typed(self, tensor, dtype:str=None ) -> np.ndarray: - if self.is_typed( - tensor=tensor, - dtype=dtype - ): - return tensor - return np.array(tensor, dtype=self.dtype_from_str( - dtype=dtype + return np.asarray(tensor, dtype=self.dtype_from_str( + dtype=dtype, + numpy=True ) if dtype is not None else None) def convert_to_numpy(self, @@ -53,13 +54,9 @@ def convert_to_numpy(self, def generator(self, seed:int=None ) -> np.random.Generator: - if self.seeder is None: - if seed is None: - entropy = np.random.randint(1234567890) - else: - entropy = np.random.default_rng(seed).integers(0, 1234567890, (1, ))[0] - self.seeder = np.random.SeedSequence(entropy) - return np.random.default_rng(self.seeder.spawn(1)[0]) + if self.seed_sequence is None: + self.seed_sequence = self.get_seedsequence(seed) + return np.random.default_rng(self.seed_sequence.spawn(1)[0]) def integers(self, generator:np.random.Generator, @@ -99,19 +96,19 @@ def transpose(self, axis_0:int=None, axis_1:int=None ) -> np.ndarray: - _shape = self.shape( - tensor=tensor - ) - _axes = np.arange(len(_shape)) - if axis_0 is not None and axis_1 is not None: - _axes[axis_1] = axis_0 % len(_shape) - _axes[axis_0] = axis_1 % len(_shape) - return np.transpose(tensor, axes=_axes) - else: + if axis_0 is None or axis_1 is None: return self.convert_to_typed( tensor=tensor ).T + # get swapped axes + _shape = np.shape(tensor) + _axes = np.arange(len(_shape)) + _axes[axis_1] = axis_0 % len(_shape) + _axes[axis_0] = axis_1 % len(_shape) + + return np.transpose(tensor, axes=_axes) + def repeat(self, tensor, repeats, @@ -159,7 +156,7 @@ def stack(self, out ) -> np.ndarray: return np.stack(tensors, axis=axis, out=out) - + def update(self, tensor, indices, @@ -167,7 +164,7 @@ def update(self, ) -> np.ndarray: tensor[indices] = values return tensor 
- + def if_else(self, condition, func_true, @@ -186,4 +183,4 @@ def iterate_i(self, ): for i in range(iterations_i): Y = func(i, Y, args) - return Y \ No newline at end of file + return Y diff --git a/quantrl/backends/torch.py b/quantrl/backends/torch.py index f99e686..863e7be 100644 --- a/quantrl/backends/torch.py +++ b/quantrl/backends/torch.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.backends.torch' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-10-09" +__updated__ = "2024-10-13" # dependencies import numpy as np @@ -16,9 +16,18 @@ from .base import BaseBackend class TorchBackend(BaseBackend): + """Backend to interface the PyTorch library. + + Parameters + ---------- + precision: str, default='double' + Precision of the numerical values in the backend. Options are ``'single'`` and ``'double'``. + device: str, default='gpu' + Device for the backend. Options are ``'cpu'`` and ``'gpu'``. + """ def __init__(self, precision:str='double', - device:str='cuda' + device:str='gpu' ): # initialize BaseBackend super().__init__( @@ -29,16 +38,13 @@ def __init__(self, ) # set default device - assert 'cpu' in device or 'cuda' in device, "Invalid precision opted, options are ``'cpu'`` and ``'cuda'``." - if 'cuda' in device and not torch.cuda.is_available(): + assert 'cpu' in device or 'gpu' in device, "Invalid device opted, options are ``'cpu'`` and ``'gpu'``." 
+ if 'gpu' in device and not torch.cuda.is_available(): print("CUDA not available, defaulting to ``'cpu'``") device = 'cpu' - torch.set_default_device(device) - self.device = device + self.device = 'cuda' if 'gpu' in device else device + torch.set_default_device(self.device) - # set seeder - self.seeder = None - def convert_to_typed(self, tensor, dtype:str=None @@ -64,14 +70,10 @@ def convert_to_numpy(self, def generator(self, seed:int=None ) -> torch.Generator: - if self.seeder is None: - if seed is None: - entropy = np.random.randint(1234567890) - else: - entropy = np.random.default_rng(seed).integers(0, 1234567890, (1, ))[0] - self.seeder = np.random.SeedSequence(entropy) + if self.seed_sequence is None: + self.seed_sequence = self.get_seedsequence(seed) generator = torch.Generator(device=self.device) - generator.manual_seed(self.seeder.spawn(1)[0]) + generator.manual_seed(int(self.seed_sequence.spawn(1)[0].generate_state(1)[0])) return generator def integers(self, @@ -113,11 +115,11 @@ def transpose(self, axis_0:int=None, axis_1:int=None ) -> torch.Tensor: - if axis_0 is not None and axis_1 is not None: - return torch.transpose(tensor, dim0=axis_0, dim1=axis_1) - return self.convert_to_typed( - tensor=tensor - ).T + if axis_0 is None or axis_1 is None: + return self.convert_to_typed( + tensor=tensor + ).T + return torch.transpose(tensor, dim0=axis_0, dim1=axis_1) def repeat(self, tensor, @@ -166,7 +168,7 @@ def stack(self, out ) -> torch.Tensor: return torch.stack(tensors, dim=axis, out=out) - + def update(self, tensor, indices, @@ -174,7 +176,7 @@ def update(self, ) -> torch.Tensor: tensor[indices] = values return tensor - + def if_else(self, condition, func_true, @@ -193,4 +195,4 @@ def iterate_i(self, ): for i in range(iterations_i): Y = func(i, Y, args) - return Y \ No newline at end of file + return Y diff --git a/quantrl/envs/base.py b/quantrl/envs/base.py index 6e97d5e..f3e270e 100644 --- a/quantrl/envs/base.py +++ b/quantrl/envs/base.py @@ -6,16 +6,18 @@ 
__name__ = 'quantrl.envs.base' __authors__ = ["Sampreet Kalita"] __created__ = "2023-04-25" -__updated__ = "2024-07-23" 
+__updated__ = "2024-10-14" # dependencies from abc import ABC, abstractmethod +import sys + from gymnasium import Env from gymnasium.spaces import Box, MultiDiscrete +import numpy as np from stable_baselines3.common import env_util from stable_baselines3.common.vec_env import VecEnv from tqdm.rich import tqdm -import numpy as np # quantrl modules from ..backends.base import BaseBackend @@ -84,29 +86,29 @@ class BaseEnv(ABC): ======================== ================================================ """ - default_axis_args_learning_curve=['Episodes', 'Average Return', [np.sqrt(10) * 1e-4, np.sqrt(10) * 1e6], 'log'] + default_axis_args_learning_curve = ['Episodes', 'Average Return', [np.sqrt(10) * 1e-4, np.sqrt(10) * 1e6], 'log'] """list: Default axis arguments to plot the learning curve.""" - base_env_kwargs = dict( - has_delay=False, - observation_space_range=[-1e12, 1e12], - observation_stds=None, - action_space_range=[-1.0, 1.0], - action_space_type='box', - seed=None, - cache_all_data=True, - cache_dump_interval=100, - average_over=100, - plot=True, - plot_interval=10, - plot_idxs=[-1], - axes_args=[ + base_env_kwargs = { + 'has_delay': False, + 'observation_space_range': [-1e12, 1e12], + 'observation_stds': None, + 'action_space_range': [-1.0, 1.0], + 'action_space_type': 'box', + 'seed': None, + 'cache_all_data': True, + 'cache_dump_interval': 100, + 'average_over': 100, + 'plot': True, + 'plot_interval': 10, + 'plot_idxs': [-1], + 'axes_args': [ ['$t / \\tau$', '$\\tilde{R}$', [np.sqrt(10) * 1e-5, np.sqrt(10) * 1e4], 'log'] ], - axes_lines_max=10, - axes_cols=2, - plot_buffer=False - ) + 'axes_lines_max': 10, + 'axes_cols': 2, + 'plot_buffer': False + } """dict: Default values of all keyword arguments.""" def __init__(self, @@ -127,7 +129,7 @@ def __init__(self, """Class constructor for BaseEnv.""" # update keyword arguments - for key in self.base_env_kwargs: + for key, _ in self.base_env_kwargs.items(): kwargs[key] = kwargs.get(key, 
self.base_env_kwargs[key]) # validate arguments @@ -135,9 +137,9 @@ def __init__(self, assert n_properties >= 0, "parameter ``n_properties`` should be non-negative" assert action_interval > 0, "parameter ``action_interval`` should be a positive integer" assert len(data_idxs) > 0, "parameter ``data_idxs`` should be a list containing at least one element" - assert kwargs['observation_stds'] is None or type(kwargs['observation_stds']) is list, "parameter ``observation_stds`` should be a list" - assert kwargs['seed'] is None or type(kwargs['seed']) is int, "parameter ``seed`` should be an integer or ``None``" - assert type(kwargs['cache_all_data']) is bool, "parameter ``cache_all_data`` should be a boolean" + assert kwargs['observation_stds'] is None or isinstance(kwargs['observation_stds'], list), "parameter ``observation_stds`` should be a list" + assert kwargs['seed'] is None or isinstance(kwargs['seed'], int), "parameter ``seed`` should be an integer or ``None``" + assert isinstance(kwargs['cache_all_data'], bool), "parameter ``cache_all_data`` should be a boolean" assert kwargs['plot_interval'] > 0, "parameter ``plot_interval`` should be a positive integer" assert len(kwargs['plot_idxs']) == len(kwargs['axes_args']), "number of indices for plot should match number of axes arguments" assert len(kwargs['observation_space_range']) == 2, "parameter ``observation_space_range`` should contain two elements for the minimum and maximum values, both inclusive" @@ -253,12 +255,19 @@ def __init__(self, ) # initialize buffers + self.t_idx = 0 + self.t = self.numpy_real(0.0) + self.action_idx = 0 self._idx_s = 0 self.T_step = None self.States = None self.Observations = None + self.Observation_noises = None self.Reward = None self.Properties = None + if self.plot: + self.plotter_env_idxs = None + self.plotter_env_data = None @abstractmethod def _update_states(self): @@ -307,12 +316,12 @@ def get_reward(self): raise NotImplementedError - def validate_environment(self, + def 
validate_base(self, shape_reset_states:tuple, shape_get_properties:tuple, shape_get_reward:tuple ): - """Method to validate the interfaced environment. + """Method to validate the base environment. Parameters ---------- @@ -331,7 +340,7 @@ def validate_environment(self, ) assert self.backend.shape( tensor=states_0 - ) == shape_reset_states, "``reset_states`` should return an array with shape ``{}``".format(shape_reset_states) + ) == shape_reset_states, f"``reset_states`` should return an array with shape ``{shape_reset_states}``" # initialize states self.States = self.backend.repeat( tensor=self.backend.reshape( @@ -354,17 +363,17 @@ def validate_environment(self, ) assert self.backend.shape( tensor=self.Properties - ) == shape_get_properties, "``get_properties`` should return an array with shape ``{}``".format(shape_get_properties) + ) == shape_get_properties, f"``get_properties`` should return an array with shape ``{shape_get_properties}``" # validate reward self.Reward = self.backend.convert_to_typed( tensor=self.get_reward() ) assert self.backend.shape( tensor=self.Reward - ) == shape_get_reward, "``get_reward`` should return an array with shape ``{}``".format(shape_get_reward) + ) == shape_get_reward, f"``get_reward`` should return an array with shape ``{shape_get_reward}``" except AttributeError as error: print(f"Missing required method or attribute: ({error}). Refer to **Notes** of :class:`quantrl.envs.base.BaseEnv` for the implementation format of the missing method or add the missing attribute to the ``reset_states`` method.") - exit() + sys.exit() def reset(self): """Method to reset the time and obtain initial states as a typed tensor. 
@@ -413,13 +422,8 @@ def reset(self): repeats=self.shape_T[0], axis=0 ) - else: - self.Observation_noises = self.backend.zeros( - shape=(self.shape_T[0], *_shape), - dtype='real' - ) # initialize observations - observations_0 = states_0 + self.Observation_noises[0] + observations_0 = states_0 + (self.Observation_noises[0] if self.observation_stds is not None else 0.0) self.Observations = self.backend.repeat( tensor=self.backend.reshape( tensor=observations_0, @@ -455,7 +459,7 @@ def update(self): # update actual states and observed states self.States = self._update_states() - self.Observations = self.States + self.Observation_noises[self.t_idx:self.t_idx + _dim_T_step] + self.Observations = self.States + (self.Observation_noises[self.t_idx:self.t_idx + _dim_T_step] if self.observation_stds is not None else 0.0) # update properties if self.n_properties > 0: @@ -474,7 +478,7 @@ def update(self): self.action_idx += 1 # check if completed - terminated = False if self.t_idx + 1 < _dim_T else True + terminated = not self.t_idx + 1 < _dim_T return self.Observations[_dim_T_step - 1], self.Reward[_dim_T_step - 1], terminated @@ -629,11 +633,11 @@ def replay_trajectories(self, # close plotter self.plotter.close() - def close(self, + def close_base(self, n_episodes, save_replay=True ): - """Method to close the environment. + """Method to close the base environment. 
Parameters ---------- @@ -648,7 +652,7 @@ def close(self, self.plotter.make_gif( file_name=self.file_path_prefix + '_' + '_'.join([ 'replay', - str(0), + str(self._idx_s), str(n_episodes - 1), str(self.plot_interval) ]) @@ -725,12 +729,14 @@ def __init__(self, dtype='real' ) self.rewards = None - self.data_rewards = list() + self.data_rewards = [] self.all_data = None self.data = None - def validate_environment(self): - return super().validate_environment( + def validate(self): + """Method to validate BaseGymEnv.""" + + return super().validate_base( shape_reset_states=(self.n_observations, ), shape_get_properties=(self.action_interval + 1, self.n_properties), shape_get_reward=(self.action_interval + 1, ) @@ -773,7 +779,7 @@ def reset(self, } def step(self, - actions + action ): """Method to take one single step and obtain the observations and reward as NumPy arrays or typed tensors. @@ -798,7 +804,7 @@ def step(self, # set actions self.actions = self.backend.convert_to_typed( - tensor=actions, + tensor=action, dtype='real' ) * self.action_maximums @@ -810,7 +816,8 @@ def step(self, # check if truncation required truncated = self.check_truncation() - print(f'Trajectory #{self.traj_idx} truncated') if truncated > 0 else True + if truncated > 0: + print(f'Trajectory #{self.traj_idx} truncated') # if trajectory ends if terminated or truncated: @@ -898,6 +905,8 @@ def evolve(self, Parameters ---------- + show_progress: bool, default=True + Option to display the progress. close: bool, default=True Option to close the environment. 
""" @@ -982,7 +991,7 @@ def close(self, ) # clean - super().close( + super().close_base( n_episodes=self.traj_idx, save_replay=save ) @@ -1034,7 +1043,7 @@ def __init__(self, # update attributes self.n_envs = n_envs - self.render_mode = [None] * n_envs + self.render_mode = None self.action_maximums_batch = self.backend.repeat( tensor=self.backend.reshape( tensor=self.action_maximums, @@ -1071,13 +1080,15 @@ def __init__(self, dtype='real' ) self.rewards = None - self.data_rewards = list() + self.data_rewards = [] self.data = None if self.plot: self.env_idx_arr = np.arange(self.n_envs, dtype=self.numpy_int) - def validate_environment(self): - return super().validate_environment( + def validate(self): + """Method to validate BaseSB3Env.""" + + return super().validate_base( shape_reset_states=(self.n_envs, self.n_observations), shape_get_properties=(self.action_interval + 1, self.n_envs, self.n_properties), shape_get_reward=(self.action_interval + 1, self.n_envs) @@ -1133,7 +1144,7 @@ def env_method(self, def get_attr(self, attr_name, - indices=None + indices=None ): """Method to obtain attributes of the sub-environments. @@ -1169,7 +1180,7 @@ def set_attr(self, Indices of the environments. If ``None``, the values for all sub-environments are considered. 
""" - [setattr(self, attr_name, value) for _ in range(indices if indices is not None else self.n_envs)] + return [setattr(self, attr_name, value) for _ in range(indices if indices is not None else self.n_envs)] def reset(self, seed:float=None, @@ -1259,17 +1270,18 @@ def step_wait(self): # check if truncation required truncated = self.check_truncation() - print(f'Batch #{self.batch_idx} truncated') if truncated > 0 else True + if truncated > 0: + print(f'Batch #{self.batch_idx} truncated') # if trajectory ends if terminated or truncated: # update plotter and io if self.plot: - for _i in range(len(self.plotter_env_idxs)): + for i, plotter_env_idx in enumerate(self.plotter_env_idxs): self.plotter.plot_lines( xs=self.T_norm, - Y=self.plotter_env_data[_i, :, :], - traj_idx=self.batch_idx * self.n_envs + self.plotter_env_idxs[_i], + Y=self.plotter_env_data[i, :, :], + traj_idx=self.batch_idx * self.n_envs + plotter_env_idx, update_buffer=self.plot_buffer ) # update episode reward @@ -1312,6 +1324,7 @@ def update_data(self): axis_0=1, axis_1=0 ) + _Properties = None if self.n_properties > 0: _Properties = self.backend.transpose( tensor=self.Properties[:_dim], @@ -1375,6 +1388,8 @@ def evolve(self, Parameters ---------- + show_progress: bool, default=True + Option to display the progress. close: bool, default=True Option to close the environment. 
save: bool, default=False @@ -1397,7 +1412,7 @@ def evolve(self, # udpate reward super().update() - + # update data self.update_data() @@ -1471,7 +1486,7 @@ def close(self, ) # clean - super().close( + super().close_base( n_episodes=self.batch_idx, save_replay=save - ) \ No newline at end of file + ) diff --git a/quantrl/envs/deterministic.py b/quantrl/envs/deterministic.py index 2d67bd9..90a4c1a 100644 --- a/quantrl/envs/deterministic.py +++ b/quantrl/envs/deterministic.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.envs.deterministic' __authors__ = ["Sampreet Kalita"] __created__ = "2023-04-25" -__updated__ = "2024-10-09" +__updated__ = "2024-10-14" # quantrl modules from ..backends.context_manager import get_backend_instance @@ -76,14 +76,14 @@ class LinearizedHOEnv(BaseGymEnv): ============ ================================================ """ - default_params = dict() + default_params = {} """dict: Default parameters of the environment.""" - default_ode_solver_params = dict( - ode_method='vode', - ode_atol=1e-9, - ode_rtol=1e-6 - ) + default_ode_solver_params = { + 'ode_method': 'vode', + 'ode_atol': 1e-9, + 'ode_rtol': 1e-6 + } """dict: Default parameters of the ODE solver.""" backend_libraries = ['torch', 'jax', 'numpy'] @@ -106,14 +106,14 @@ def __init__(self, data_idxs:list, backend_library:str='numpy', backend_precision:str='double', - backend_device:str='cuda', + backend_device:str='gpu', dir_prefix:str='data', **kwargs ): """Class constructor for LinearizedHOEnv.""" # validate arguments - assert backend_library in self.backend_libraries, "parameter ``solver_type`` should be one of ``{}``".format(self.backend_libraries) + assert backend_library in self.backend_libraries, f"parameter ``solver_type`` should be one of ``{self.backend_libraries}``" # select backend backend = get_backend_instance( @@ -133,11 +133,11 @@ def __init__(self, self.num_corrs = num_quads**2 # set parameters - self.params = dict() - for key in self.default_params: + self.params = {} + for 
key, _ in self.default_params.items(): self.params[key] = params.get(key, self.default_params[key]) # update keyword arguments - for key in self.default_ode_solver_params: + for key, _ in self.default_ode_solver_params.items(): kwargs[key] = kwargs.get(key, self.default_ode_solver_params[key]) # set matrices self.A = backend.zeros( @@ -502,14 +502,14 @@ class LinearizedHOVecEnv(BaseSB3Env): ============ ================================================ """ - default_params = dict() + default_params = {} """dict: Default parameters of the environment.""" - default_ode_solver_params = dict( - ode_method='vode', - ode_atol=1e-9, - ode_rtol=1e-6 - ) + default_ode_solver_params = { + 'ode_method': 'vode', + 'ode_atol': 1e-9, + 'ode_rtol': 1e-6 + } """dict: Default parameters of the ODE solver.""" backend_libraries = ['torch', 'jax', 'numpy'] @@ -533,35 +533,24 @@ def __init__(self, data_idxs:list, backend_library:str='numpy', backend_precision:str='double', - backend_device:str='cuda', + backend_device:str='gpu', dir_prefix:str='data', **kwargs ): """Class constructor for LinearizedHOEnv.""" # validate arguments - assert backend_library in self.backend_libraries, "parameter ``solver_type`` should be one of ``{}``".format(self.backend_libraries) + assert backend_library in self.backend_libraries, f"parameter ``solver_type`` should be one of ``{self.backend_libraries}``" # select backend - if 'torch' in backend_library: - from ..backends.torch import TorchBackend - from ..solvers.torch import TorchDiffEqIVPSolver as IVPSolverClass - backend = TorchBackend( - precision=backend_precision, - device=backend_device - ) - elif 'jax' in backend_library: - from ..backends.jax import JaxBackend - from ..solvers.jax import DiffraxIVPSolver as IVPSolverClass - backend = JaxBackend( - precision=backend_precision - ) - else: - from ..backends.numpy import NumPyBackend - from ..solvers.numpy import SciPyIVPSolver as IVPSolverClass - backend = NumPyBackend( - precision=backend_precision 
- ) + backend = get_backend_instance( + library=backend_library, + precision=backend_precision, + device=backend_device + ) + IVPSolver = get_IVP_solver( + library=backend_library + ) # set constants self.name = name @@ -571,11 +560,11 @@ def __init__(self, self.num_corrs = num_quads**2 # set parameters - self.params = dict() - for key in self.default_params: + self.params = {} + for key, _ in self.default_params.items(): self.params[key] = params.get(key, self.default_params[key]) # update keyword arguments - for key in self.default_ode_solver_params: + for key, _ in self.default_ode_solver_params.items(): kwargs[key] = kwargs.get(key, self.default_ode_solver_params[key]) # set matrices self.A = backend.zeros( @@ -608,7 +597,7 @@ def __init__(self, ) # initialize solver - self.solver = IVPSolverClass( + self.solver = IVPSolver( func=self.func, y_0=self.States[-1], T=self.T, @@ -875,4 +864,4 @@ def get_mode_rates_real(self, ), axis=1, out=self.mode_rates_real - ) \ No newline at end of file + ) diff --git a/quantrl/envs/stochastic.py b/quantrl/envs/stochastic.py index 792dfd6..bf0f8f2 100644 --- a/quantrl/envs/stochastic.py +++ b/quantrl/envs/stochastic.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.envs.stochastic' __authors__ = ["Sampreet Kalita"] __created__ = "2023-04-25" -__updated__ = "2024-10-09" +__updated__ = "2024-10-14" # dependencies import numpy as np @@ -68,7 +68,7 @@ class LinearEnv(BaseGymEnv): Keyword arguments. Refer to the ``kwargs`` parameter of :class:`quantrl.envs.base.BaseEnv` for available options. 
""" - default_params = dict() + default_params = {} """dict: Default parameters of the environment.""" backend_libraries = ['torch', 'jax', 'numpy'] @@ -89,14 +89,14 @@ def __init__(self, data_idxs:list, backend_library:str='numpy', backend_precision:str='double', - backend_device:str='cuda', + backend_device:str='gpu', dir_prefix:str='data', **kwargs ): """Class constructor for LinearEnv.""" # validate arguments - assert backend_library in self.backend_libraries, "parameter ``solver_type`` should be one of ``{}``".format(self.backend_libraries) + assert backend_library in self.backend_libraries, f"parameter ``solver_type`` should be one of ``{self.backend_libraries}``" # select backend backend = get_backend_instance( @@ -110,10 +110,11 @@ def __init__(self, self.desc = desc # set parameters - self.params = dict() - for key in self.default_params: + self.params = {} + for key, _ in self.default_params.items(): self.params[key] = params.get(key, self.default_params[key]) # set buffers + self.Ws = None self.I = backend.eye( N=n_observations, dtype='real' @@ -137,7 +138,7 @@ def __init__(self, action_interval=action_interval, data_idxs=data_idxs, dir_prefix=(dir_prefix if dir_prefix != 'data' else ('data/' + self.name.lower()) + '/env') + '_' + '_'.join([ - str(self.params[key]) for key in self.params + str(val) for _, val in self.params.items() ]), file_prefix='lin_env', **kwargs @@ -280,4 +281,4 @@ def get_noise_prefixes(self, Noise prefixes for each observation with shape ``(n_observations, )``. 
""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/quantrl/io.py b/quantrl/io.py index 5743856..1447be8 100644 --- a/quantrl/io.py +++ b/quantrl/io.py @@ -6,19 +6,20 @@ __name__ = 'quantrl.io' __authors__ = ["Sampreet Kalita"] __created__ = "2023-12-07" -__updated__ = "2024-07-22" +__updated__ = "2024-10-14" # dependencies +import gc +import os from threading import Thread + from tqdm.rich import tqdm -import gc import numpy as np -import os # TODO: Implement ConsoleIO # TODO: Add ``load_parts`` method -class FileIO(object): +class FileIO(): """Handler for file input-output. Initializes ``cache`` to ``None`` and ``index`` to ``-1``. @@ -40,7 +41,7 @@ def __init__(self, """Class constructor for FileIO.""" # set attributes - assert type(cache_dump_interval) is int and cache_dump_interval > 0, "parameter ``disk_cache_size`` should be a positive integer" + assert isinstance(cache_dump_interval, int) and cache_dump_interval > 0, "parameter ``disk_cache_size`` should be a positive integer" self.disk_cache_dir = disk_cache_dir self.cache_dump_interval = cache_dump_interval try: @@ -148,7 +149,7 @@ def get_disk_cache(self, """ # iterate over parts - self.cache_list = list() + cache_list = [] for i in tqdm( range(int(idx_start / self.cache_dump_interval) * self.cache_dump_interval, idx_end + 1, self.cache_dump_interval), desc="Loading", @@ -163,12 +164,12 @@ def get_disk_cache(self, idx_start=i, idx_end=_idx_e ) - self.cache_list += [_cache[:, :, idxs].copy() if idxs is not None else _cache.copy()] + cache_list += [_cache[:, :, idxs].copy() if idxs is not None else _cache.copy()] # clear loaded cache del _cache gc.collect() - return np.concatenate(self.cache_list)[idx_start % self.cache_dump_interval:] + return np.concatenate(cache_list)[idx_start % self.cache_dump_interval:] def _load_cache(self, idx_start:int, @@ -242,4 +243,4 @@ def close(self, ) # clean - del self \ No newline at end of file + del self diff --git 
a/quantrl/plotters.py b/quantrl/plotters.py index 76e61b3..e78e0d5 100644 --- a/quantrl/plotters.py +++ b/quantrl/plotters.py @@ -6,15 +6,16 @@ __name__ = 'quantrl.plotters' __authors__ = ["Sampreet Kalita"] __created__ = "2023-12-08" -__updated__ = "2024-07-23" +__updated__ = "2024-10-14" # dependencies from io import BytesIO +import os + from matplotlib.gridspec import GridSpec -from PIL import Image import matplotlib.pyplot as plt import numpy as np -import os +from PIL import Image # OpenMP configuration os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' @@ -35,7 +36,7 @@ plt.rcParams['xtick.labelsize'] = 14 plt.rcParams['ytick.labelsize'] = 14 -class TrajectoryPlotter(object): +class TrajectoryPlotter(): """Plotter for trajectories. Initializes ``axes_rows``, ``fig``, ``axes`` and ``lines``. @@ -88,7 +89,7 @@ def __init__(self, self.axes_rows = int(np.ceil(len(self.axes_args) / self.axes_cols)) self.fig = plt.figure(figsize=(6.0 * self.axes_cols, 3.0 * self.axes_rows)) self.gspec = GridSpec(self.axes_rows, self.axes_cols, figure=self.fig, width_ratios=[0.2] * self.axes_cols) - self.axes = list() + self.axes = [] self.lines = None # format frame @@ -112,7 +113,7 @@ def __init__(self, self.fig.tight_layout() # initialize buffers - self.frames = list() + self.frames = [] def plot_lines(self, xs, @@ -140,13 +141,13 @@ def plot_lines(self, line.set_alpha(0.1) # add new lines - self.lines = list() + self.lines = [] for i, ax in enumerate(self.axes): if self.axes_lines_max and len(ax.get_lines()) >= self.axes_lines_max: line = ax.get_lines()[0] line.remove() self.lines.append(ax.plot(xs, Y[:, i])[0]) - if self.show_title and self.fig._suptitle is not None: + if self.show_title: self.fig.suptitle('#' + str(traj_idx)) self.fig.canvas.draw() self.fig.canvas.flush_events() @@ -187,7 +188,7 @@ def make_gif(self, # reset buffer del self.frames - self.frames = list() + self.frames = [] def save_plot(self, file_name:str @@ -223,7 +224,7 @@ def close(self): # clean del self -class 
LearningCurvePlotter(object): +class LearningCurvePlotter(): """Plotter for learning curve. Initializes ``fig``, ``ax`` and ``data``. @@ -234,21 +235,21 @@ class LearningCurvePlotter(object): Lists of axis properties. The first element of each entry is the ``x_label``, the second is ``y_label``, the third is ``[y_limit_min, y_limit_max]`` and the fourth is ``y_scale``. average_over: int, default=100 Number of points to average over. - percentiles: list, default=[25, 50, 75] - Percentile values for intraquartile ranges. + percentiles: list, default=None + Percentile values for intraquartile ranges. If ``None``, the percentiles are set to ``[25, 50, 75]``. """ def __init__(self, axis_args:list, average_over:int=100, - percentiles:list=[25, 50, 75] + percentiles:list=None ): """Class constructor for LearningCurvePlotter.""" # set attributes self.axis_args = axis_args self.average_over = average_over - self.percentiles = percentiles + self.percentiles = percentiles if percentiles is not None else [25, 50, 75] # turn on interactive mode plt.ion() @@ -263,7 +264,7 @@ def __init__(self, self.fig.tight_layout() # initialze buffer - self.data = list() + self.data = [] self.line = None self.line_faint = None @@ -291,12 +292,13 @@ def add_data(self, """ # if averaging opted + data_rewards_smooth = data_rewards if self.average_over is not None: data_rewards_smooth = np.convolve(data_rewards, np.ones((self.average_over, )) / float(self.average_over), mode='valid') # update data if renew: - self.data = list() + self.data = [] self.line = None self.line_faint = None self.data.append(data_rewards_smooth) @@ -351,4 +353,4 @@ def close(self): plt.close() # clean - del self \ No newline at end of file + del self diff --git a/quantrl/solvers/base.py b/quantrl/solvers/base.py index 8711d98..aaad82c 100644 --- a/quantrl/solvers/base.py +++ b/quantrl/solvers/base.py @@ -6,10 +6,11 @@ __name__ = 'quantrl.solvers.base' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" 
-__updated__ = "2024-05-29" +__updated__ = "2024-10-14" # dependencies from abc import ABC, abstractmethod + from tqdm import tqdm # quantrl modules @@ -55,6 +56,7 @@ class BaseIVPSolver(ABC): ..note: In the presence of delay, the parameter ``'step_interval'`` is overriden by the delay interval. """ + # attributes default_solver_params = { 'method': 'vode', 'atol': 1e-12, @@ -64,6 +66,8 @@ class BaseIVPSolver(ABC): 'complex': False } """dict: Default parameters of the solver.""" + solver_methods = [] + """list: Methods used by the solver.""" def __init__(self, func, @@ -100,20 +104,20 @@ def __init__(self, ) # set params - self.solver_params = dict() - for key in self.default_solver_params: + self.solver_params = {} + for key, _ in self.default_solver_params.items(): self.solver_params[key] = solver_params.get(key, self.default_solver_params[key]) # override step dimension with delay interval if DDE if self.has_delay and self.delay_interval != 0: self.solver_params['step_interval'] = self.delay_interval # validate params - assert self.solver_params['method'] in self.solver_methods, "parameter ``method`` should be one of ``{}``".format(self.solver_methods) - assert type(self.solver_params['step_interval']) is int and self.solver_params['step_interval'] < self.shape_T[0], "parameter ``step_interval`` should be an integer with a value less than the total number of steps" + assert self.solver_params['method'] in self.solver_methods, f"parameter ``method`` should be one of ``{self.solver_methods}``" + assert isinstance(self.solver_params['step_interval'], int) and self.solver_params['step_interval'] < self.shape_T[0], "parameter ``step_interval`` should be an integer with a value less than the total number of steps" # step constants self.step_interval = self.solver_params['step_interval'] - + @abstractmethod def integrate(self, T_step, @@ -138,7 +142,7 @@ def integrate(self, """ raise NotImplementedError - + @abstractmethod def interpolate(self, T_step, @@ -193,7 +197,7 
@@ def step(self, # update delay function if self.has_delay: self.func_delay = self.interpolate( - T=T_step, + T_step=T_step, Y=_Y ) @@ -249,4 +253,4 @@ def solve_ivp(self, params=params )[1:] - return Y \ No newline at end of file + return Y diff --git a/quantrl/solvers/context_manager.py b/quantrl/solvers/context_manager.py index 05d2aff..1dfcca1 100644 --- a/quantrl/solvers/context_manager.py +++ b/quantrl/solvers/context_manager.py @@ -6,16 +6,28 @@ __name__ = 'quantrl.solvers.context_manager' __authors__ = ["Sampreet Kalita"] __created__ = "2024-10-09" -__updated__ = "2024-10-09" +__updated__ = "2024-10-14" # quantrl modules from .base import BaseIVPSolver -IVP_SOLVERS = dict() +IVP_SOLVERS = {} def get_IVP_solver( library:str ) -> BaseIVPSolver: + """Method to obtain an IVP solver class. + + Parameters + ---------- + library: str + Name of the library. Options are ``'jax'``, ``'numpy'`` and ``'torch'``. + + Returns + ------- + IVPSolver: :class:`quantrl.solvers.base.BaseIVPSolver` + The IVP solver class. 
+ """ if library in IVP_SOLVERS: return IVP_SOLVERS[library] if 'jax' in library.lower(): @@ -31,4 +43,4 @@ def get_IVP_solver( from .numpy import SciPyIVPSolver IVP_SOLVERS['numpy'] = SciPyIVPSolver library = 'numpy' - return IVP_SOLVERS[library] \ No newline at end of file + return IVP_SOLVERS[library] diff --git a/quantrl/solvers/jax.py b/quantrl/solvers/jax.py index d391c18..504e581 100644 --- a/quantrl/solvers/jax.py +++ b/quantrl/solvers/jax.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.solvers.jax' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-03-23" +__updated__ = "2024-10-14" # dependencies import jax @@ -72,7 +72,7 @@ def integrate(self, y_0 = self.backend.convert_to_typed( tensor=y_0 ) - + # integrate return dfx.diffeqsolve( terms=self.term, @@ -86,11 +86,11 @@ def integrate(self, stepsize_controller=dfx.PIDController( atol=self.solver_params['atol'], rtol=self.solver_params['rtol'] - ) + ) ).ys - + def interpolate(self, T_step, Y ): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/quantrl/solvers/measure.py b/quantrl/solvers/measure.py index d5b8452..ef90e2d 100644 --- a/quantrl/solvers/measure.py +++ b/quantrl/solvers/measure.py @@ -24,10 +24,11 @@ __name__ = 'qom.solvers.measure' __authors__ = ["Sampreet Kalita"] __created__ = "2021-01-04" -__updated__ = "2023-05-29" +__updated__ = "2024-10-14" # dependencies from typing import Union + import numpy as np class QCMSolver(): @@ -51,7 +52,7 @@ class QCMSolver(): 'measure_codes' (*list* or *str*) codenames of the measures to calculate. Options are ``'discord_G'`` for Gaussian quantum discord [3]_, ``'entan_ln'`` for quantum entanglement (using matrix multiplications, fallback) [1]_, ``'entan_ln_2'`` for quantum entanglement (using analytical expressions) [2]_, ``'sync_c'`` for complete quantum synchronization [5]_, ``'sync_p'`` for quantum phase synchronization [5]_). Default is ``['entan_ln']``. 
'indices' (*list* or *tuple*) indices of the modes as a list or tuple of two integers. Default is ``(0, 1)``. ================ ==================================================== - cb_update : callable, optional + cb_update : callback, optional Callback function to update status and progress, formatted as ``cb_update(status, progress, reset)``, where ``status`` is a string, ``progress`` is a float and ``reset`` is a boolean. """ @@ -91,7 +92,10 @@ def __init__(self, Modes, Corrs, params:dict, cb_update=None): # set parameters self.set_params(params) - + + # set callback + self.cb_update = cb_update + def set_params(self, params:dict): """Method to validate and set the solver parameters. @@ -113,26 +117,26 @@ def set_params(self, params:dict): # validate measure type assert isinstance(measure_codes, Union[list, str].__args__), "Value of key ``'measure_codes'`` can only be of types ``list`` or ``str``" # convert to list - measure_codes = [measure_codes] if type(measure_codes) is str else measure_codes + measure_codes = [measure_codes] if isinstance(measure_codes, str) else measure_codes # check elements for measure_code in measure_codes: - assert measure_code in self.method_codes, "Elements of key ``'measure_codes'`` can only be one or more keys of ``{}``".format(self.method_codes.keys) + assert measure_code in self.method_codes, f"Elements of key ``'measure_codes'`` can only be one or more keys of ``{self.method_codes.keys()}``" # update parameter params['measure_codes'] = measure_codes # validate indices assert isinstance(indices, Union[list, tuple].__args__), "Value of key ``'indices'`` can only be of types ``list`` or ``tuple``" # convert to list - indices = list(indices) if type(indices) is tuple else indices + indices = list(indices) if isinstance(indices, tuple) else indices # check length assert len(indices) == 2, "Value of key ``'indices'`` can only have 2 elements" - assert indices[0] < _dim and indices[1] < _dim, "Elements of key ``'indices'`` cannot 
exceed the total number of modes ({})".format(_dim) + assert indices[0] < _dim and indices[1] < _dim, f"Elements of key ``'indices'`` cannot exceed the total number of modes ({_dim})" # update parameter params['indices'] = indices # set solver parameters - self.params = dict() - for key in self.solver_defaults: + self.params = {} + for key, _ in self.solver_defaults.items(): self.params[key] = params.get(key, self.solver_defaults[key]) def get_measures(self): @@ -156,11 +160,10 @@ def get_measures(self): # find measures for j in range(_dim[1]): # display progress - if show_progress: - self.updater.update_progress( - pos=None, - dim=_dim[1], + if show_progress and self.cb_update is not None: + self.cb_update( status="-" * (35 - len(measure_codes[j])) + "Obtaining Measures (" + measure_codes[j] + ")", + progress=0.0, reset=False ) @@ -170,13 +173,15 @@ def get_measures(self): Measures[:, j] = getattr(self, func_name)(pos_i=2 * indices[0], pos_j=2 * indices[1]) if 'corrs_P_p' not in measure_codes[j] else getattr(self, func_name)(pos_i=2 * indices[0] + 1, pos_j=2 * indices[1] + 1) # display completion - if show_progress: - self.updater.update_info( - status="-" * 39 + "Measures Obtained" + if show_progress and self.cb_update is not None: + self.cb_update( + status="-" * 39 + "Measures Obtained", + progress=1.0, + reset=False ) return Measures - + def get_submatrices(self, pos_i:int, pos_j:int): """Helper function to obtain the block matrices of the required modes and its components. @@ -247,7 +252,7 @@ def get_invariants(self, pos_i:int, pos_j:int): # symplectic invariants return np.linalg.det(As), np.linalg.det(Bs), np.linalg.det(Cs), np.linalg.det(Corrs_modes) - + def get_correlation_Pearson(self, pos_i:int, pos_j:int): r"""Method to obtain the Pearson correlation coefficient. 
@@ -328,7 +333,7 @@ def get_discord_Gaussian(self, pos_i:int, pos_j:int): # update W values Ws[conditions_W_1] = ((2 * np.abs(I_3s[conditions_W_1]) + np.sqrt(_discriminants[conditions_W_1])) / _divisors[conditions_W_1])**2 # W values without main condition - # check sqrt and NaN condtition + # check sqrt and NaN condtition _bs = np.multiply(I_1s, I_2s) + I_4s - I_3s**2 _4acs = 4 * np.multiply(np.multiply(I_1s, I_2s), I_4s) _discriminants = _bs**2 - _4acs @@ -339,9 +344,10 @@ def get_discord_Gaussian(self, pos_i:int, pos_j:int): # all validity conditions conditions = np.logical_and(conditions_mu, np.logical_or(conditions_W_1, conditions_W_2)) - # f function - func_f = lambda x: np.multiply(x + 0.5, np.log10(x + 0.5)) - np.multiply(x - 0.5, np.log10(x - 1 / 2)) - + # f function + def func_f(x): + return np.multiply(x + 0.5, np.log10(x + 0.5)) - np.multiply(x - 0.5, np.log10(x - 1 / 2)) + # update quantum discord values Discord_G[conditions] = func_f(np.sqrt(I_2s[conditions])) \ - func_f(mu_pluses[conditions]) \ @@ -349,7 +355,7 @@ def get_discord_Gaussian(self, pos_i:int, pos_j:int): + func_f(np.sqrt(Ws[conditions])) return Discord_G - + def get_entanglement_logarithmic_negativity(self, pos_i:int, pos_j:int): """Method to obtain the logarithmic negativity entanglement values using matrices [1]_. 
@@ -378,17 +384,17 @@ def get_entanglement_logarithmic_negativity(self, pos_i:int, pos_j:int): # smallest symplectic eigenvalue eigs, _ = np.linalg.eig(np.matmul(self.Omega_s, Corrs_modes)) - eig_min = np.min(np.abs(eigs), axis=1) + eigs_min = np.min(np.abs(eigs), axis=1) # initialize entanglement - Entan_ln = np.zeros_like(eig_min, dtype=np.float_) + Entan_ln = np.zeros_like(eigs_min, dtype=np.float_) # update entanglement - for i in range(len(eig_min)): - if eig_min[i] < 0: + for i, eig_min in enumerate(eigs_min): + if eig_min < 0: Entan_ln[i] = 0 else: - Entan_ln[i] = np.maximum(0.0, - np.log(2 * eig_min[i])) + Entan_ln[i] = np.maximum(0.0, - np.log(2 * eig_min)) return Entan_ln @@ -430,9 +436,9 @@ def get_entanglement_logarithmic_negativity_2(self, pos_i:int, pos_j:int): # clip negative values Entan_ln[Entan_ln < 0.0] = 0.0 - + return Entan_ln - + def get_synchronization_complete(self, pos_i:int, pos_j:int): """Method to obtain the complete quantum synchronization values [5]_. @@ -482,7 +488,7 @@ def get_synchronization_phase(self, pos_i:int, pos_j:int): cos_js = np.cos(arg_js) sin_is = np.sin(arg_is) sin_js = np.sin(arg_js) - + # transformation for ith mode momentum quadrature p_i_prime_2s = np.multiply(sin_is**2, self.Corrs[:, pos_i, pos_i]) \ - np.multiply(np.multiply(sin_is, cos_is), self.Corrs[:, pos_i, pos_i + 1]) \ @@ -522,11 +528,11 @@ def get_average_amplitude_difference(Modes): """ # validate modes - assert Modes is not None and (type(Modes) is list or type(Modes) is np.ndarray) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, 2)``" + assert Modes is not None and isinstance(Modes, (list, np.ndarray)) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, 2)``" # get means means = np.mean(Modes, axis=0) - + # average amplitude difference return np.mean([np.linalg.norm(modes[0] - means[0]) - np.linalg.norm(modes[1]- means[1]) for modes in Modes]) @@ 
-545,11 +551,11 @@ def get_average_phase_difference(Modes): """ # validate modes - assert Modes is not None and (type(Modes) is list or type(Modes) is np.ndarray) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, 2)``" + assert Modes is not None and isinstance(Modes, (list, np.ndarray)) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, 2)``" # get means means = np.mean(Modes, axis=0) - + # average phase difference return np.mean([np.angle(modes[0] - means[0]) - np.angle(modes[1]- means[1]) for modes in Modes]) @@ -568,18 +574,18 @@ def get_bifurcation_amplitudes(Modes): """ # validate modes - assert Modes is not None and (type(Modes) is list or type(Modes) is np.ndarray) and len(np.shape(Modes)) == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, num_modes)``" + assert Modes is not None and isinstance(Modes, (list, np.ndarray)) and len(np.shape(Modes)) == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, num_modes)``" # convert to real - Modes_real = np.concatenate((np.real(Modes), np.imag(Modes)), axis=1, dtype=np.float_) + Modes_real = np.concatenate((np.real(Modes), np.imag(Modes)), axis=1) # calculate gradients grads = np.gradient(Modes_real, axis=0) - + # get indices where the derivative changes sign idxs = grads[:-1, :] * grads[1:, :] < 0 - Amps = list() + Amps = [] for i in range(idxs.shape[1]): # collect all crests and troughs extremas = Modes_real[:-1, i][idxs[:, i]] @@ -609,7 +615,7 @@ def get_correlation_Pearson(Modes): """ # validate modes - assert Modes is not None and (type(Modes) is list or type(Modes) is np.ndarray) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with dimension ``(dim, 2)``" + assert Modes is not None and isinstance(Modes, (list, np.ndarray)) and np.shape(Modes)[1] == 2, "Parameter ``Modes`` should be a list or NumPy array with 
dimension ``(dim, 2)``" # get means means = np.mean(Modes, axis=0) @@ -661,10 +667,10 @@ def get_Wigner_distributions_single_mode(Corrs, params, cb_update=None): xs = params.get('wigner_xs', None) ys = params.get('wigner_ys', None) for val in [xs, ys]: - assert val is not None and (type(val) is list or type(val) is np.ndarray), "Solver parameters ``'wigner_xs'`` and ``'wigner_ys'`` should be either NumPy arrays or ``list``" + assert val is not None and isinstance(val, (list, np.ndarray)), "Solver parameters ``'wigner_xs'`` and ``'wigner_ys'`` should be either NumPy arrays or ``list``" # handle list - xs = np.array(xs, dtype=np.float_) if type(xs) is list else xs - ys = np.array(ys, dtype=np.float_) if type(xs) is list else ys + xs = np.array(xs, dtype=np.float_) if isinstance(xs, list) else xs + ys = np.array(ys, dtype=np.float_) if isinstance(xs, list) else ys # extract frequently used variables show_progress = params.get('show_progress', False) @@ -696,15 +702,14 @@ def get_Wigner_distributions_single_mode(Corrs, params, cb_update=None): # get Wigner distributions for idx_y in range(len(ys)): for idx_x in range(len(xs)): - # # display progress - # if show_progress: - # _index_status = str(j + 1) + "/" + str(dim_m) - # updater.update_progress( - # pos=idx_y * len(xs) + idx_x, - # dim=dim_w, - # status="-" * (18 - len(_index_status)) + "Obtaining Wigners (" + _index_status + ")", - # reset=False - # ) + # display progress + if show_progress and cb_update is not None: + _index_status = str(j + 1) + "/" + str(dim_m) + cb_update( + status="-" * (18 - len(_index_status)) + "Obtaining Wigners (" + _index_status + ")", + progress=(idx_y * len(xs) + idx_x) / dim_w, + reset=False + ) # wigner function Wigners[:, j, idx_y, idx_x] = np.exp(- 0.5 * np.dot(Vects_t[idx_y, idx_x], _dots[idx_y, idx_x])[0]) / 2.0 / np.pi / np.sqrt(dets) @@ -753,15 +758,14 @@ def get_Wigner_distributions_two_mode(Corrs, params, cb_update=None): xs = params.get('wigner_xs', None) ys = 
params.get('wigner_ys', None) for val in [xs, ys]: - assert val is not None and (type(val) is list or type(val) is np.ndarray), "Solver parameters ``'wigner_xs'`` and ``'wigner_ys'`` should be either NumPy arrays or ``list``" + assert val is not None and isinstance(val, (list, np.ndarray)), "Solver parameters ``'wigner_xs'`` and ``'wigner_ys'`` should be either NumPy arrays or ``list``" # handle list - xs = np.array(xs, dtype=np.float_) if type(xs) is list else xs - ys = np.array(ys, dtype=np.float_) if type(xs) is list else ys + xs = np.array(xs, dtype=np.float_) if isinstance(xs, list) else xs + ys = np.array(ys, dtype=np.float_) if isinstance(xs, list) else ys # extract frequently used variables show_progress = params.get('show_progress', False) indices = params.get('indices', [0]) - dim_m = len(indices) dim_c = len(Corrs) dim_w = len(ys) * len(xs) pos_i = 2 * indices[0][0] @@ -798,14 +802,13 @@ def get_Wigner_distributions_two_mode(Corrs, params, cb_update=None): # get Wigner distributions for idx_y in range(len(ys)): for idx_x in range(len(xs)): - # # display progress - # if show_progress: - # updater.update_progress( - # pos=idx_y * len(xs) + idx_x, - # dim=dim_w, - # status="-" * 21 + "Obtaining Wigners", - # reset=False - # ) + # display progress + if show_progress and cb_update is not None: + cb_update( + status="-" * 21 + "Obtaining Wigners", + progress=(idx_y * len(xs) + idx_x) / dim_w, + reset=False + ) # wigner function Wigners[:, idx_y, idx_x] = np.exp(- 0.5 * np.dot(Vects_t[idx_y, idx_x], _dots[idx_y, idx_x])[0]) / 4.0 / np.pi**2 / np.sqrt(dets) @@ -843,8 +846,8 @@ def validate_Modes_Corrs(Modes=None, Corrs=None, is_modes_required:bool=False, i assert Corrs is not None if is_corrs_required else True, "Missing required parameter ``Corrs``" # handle list - Modes = np.array(Modes, dtype=np.complex_) if Modes is not None and type(Modes) is list else Modes - Corrs = np.array(Corrs, dtype=np.float_) if Corrs is not None and type(Corrs) is list else Corrs + 
Modes = np.array(Modes, dtype=np.complex_) if Modes is not None and isinstance(Modes, list) else Modes + Corrs = np.array(Corrs, dtype=np.float_) if Corrs is not None and isinstance(Corrs, list) else Corrs # validate shapes assert len(Modes.shape) == 2 if Modes is not None else True, "``Modes`` should be of shape ``(dim, num_modes)``" @@ -879,7 +882,7 @@ def validate_As_Coeffs(As=None, Coeffs=None): # validate drift matrix assert isinstance(As, Union[list, np.ndarray].__args__), "``As`` should be of type ``list`` or ``numpy.ndarray``" # convert to numpy array - As = np.array(As, dtype=np.float_) if type(As) is list else As + As = np.array(As, dtype=np.float_) if isinstance(As, list) else As # validate shape assert len(As.shape) == 3 and As.shape[1] == As.shape[2], "``As`` should be of shape ``(dim_0, 2 * num_modes, 2 * num_modes)``" # if coefficients are given @@ -887,8 +890,8 @@ def validate_As_Coeffs(As=None, Coeffs=None): # validate coefficients assert isinstance(Coeffs, Union[list, np.ndarray].__args__), "``Coeffs`` should be of type ``list`` or ``numpy.ndarray``" # convert to numpy array - Coeffs = np.array(Coeffs, dtype=np.float_) if type(Coeffs) is list else Coeffs + Coeffs = np.array(Coeffs, dtype=np.float_) if isinstance(Coeffs, list) else Coeffs # validate shape assert len(Coeffs.shape) == 2, "``Coeffs`` should be of shape ``(dim_0, 2 * num_modes + 1)``" - return As, Coeffs \ No newline at end of file + return As, Coeffs diff --git a/quantrl/solvers/numpy.py b/quantrl/solvers/numpy.py index 2331bd0..a2ffeee 100644 --- a/quantrl/solvers/numpy.py +++ b/quantrl/solvers/numpy.py @@ -6,11 +6,11 @@ __name__ = 'quantrl.solvers.numpy' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-03-23" +__updated__ = "2024-10-14" # dependencies -from scipy.interpolate import splev, splrep import scipy.integrate as si +from scipy.interpolate import splev, splrep # quantrl modules from ..backends.numpy import NumPyBackend @@ -159,7 +159,7 @@ def 
integrate_flat(self, self.integrator.set_f_params(args) for i in range(1, len(T_step)): _Y_flat[i] = self.integrator.integrate(T_step[i]) - + # integrate using Python-based solvers else: _Y_flat = self.backend.transpose( @@ -176,7 +176,7 @@ def integrate_flat(self, ) return _Y_flat - + def interpolate(self, T_step, Y @@ -185,4 +185,4 @@ def interpolate(self, tensor=Y )[1] b_spline = [splrep(T_step, Y[:, j]) for j in range(_shape)] - return lambda t: [splev(t, b_spline[j]) for j in range(_shape)] \ No newline at end of file + return lambda t: [splev(t, b_spline[j]) for j in range(_shape)] diff --git a/quantrl/solvers/torch.py b/quantrl/solvers/torch.py index ebc3c9b..06b2f51 100644 --- a/quantrl/solvers/torch.py +++ b/quantrl/solvers/torch.py @@ -6,7 +6,7 @@ __name__ = 'quantrl.solvers.torch' __authors__ = ["Sampreet Kalita"] __created__ = "2024-03-10" -__updated__ = "2024-03-23" +__updated__ = "2024-10-14" # dependencies from torchdiffeq import odeint @@ -73,11 +73,11 @@ def integrate(self, atol=self.solver_params['atol'], rtol=self.solver_params['rtol'], method=self.solver_params['method'], - options=dict() + options={} ) - + def interpolate(self, T_step, Y ): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/quantrl/utils.py b/quantrl/utils.py index d0fc9f4..f96334f 100644 --- a/quantrl/utils.py +++ b/quantrl/utils.py @@ -6,16 +6,17 @@ __name__ = 'quantrl.utils' __authors__ = ["Sampreet Kalita"] __created__ = "2023-06-02" -__updated__ = "2024-08-09" +__updated__ = "2024-10-14" # dependencies +import gc +import os +import time + from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.results_plotter import ts2xy -import gc import numpy as np -import os import pandas -import time class SaveOnBestMeanRewardCallback(BaseCallback): """Utility callback class to save the best mean reward. 
@@ -61,7 +62,7 @@ def __init__(self, self.t_start = time.time() # initialize file to save mean rewards - with open(self.log_dir + f'reward_mean_{self.episode_start}_{self.n_episodes - 1}.txt', 'w') as file: + with open(self.log_dir + f'reward_mean_{self.episode_start}_{self.n_episodes - 1}.txt', 'w', encoding='utf-8') as file: s = f'{"time":>14} {"n_steps":>12} {"episode_curr":>12} {"reward_curr":>14} {"episode_best":>12} {"reward_best":>14}\n' file.write(s) file.close() @@ -79,7 +80,7 @@ def _on_step(self) -> bool: if self.n_calls % self.update_steps == 0: # retrieve reward data from monitor file - with open(self.log_dir + f'learning_{self.episode_start}_{self.n_episodes - 1}.monitor.csv') as file_handler: + with open(self.log_dir + f'learning_{self.episode_start}_{self.n_episodes - 1}.monitor.csv', 'r', encoding='utf-8') as file_handler: file_handler.readline() xs, ys = ts2xy(pandas.read_csv(file_handler, index_col=None), 'timesteps') @@ -100,15 +101,15 @@ def _on_step(self) -> bool: # update console if self.verbose >= 1: - print(f"Saving new best model and replay buffer...") + print("Saving new best model and replay buffer...") # save model and replay buffer self.model.save(self.log_dir + f'models/best_{self.n_episodes - 1}.zip') self.model.save_replay_buffer(self.log_dir + f'buffers/best_{self.n_episodes - 1}.zip') # save reward data - with open(self.log_dir + f'reward_mean_{self.episode_start}_{self.n_episodes - 1}.txt', 'a') as file: + with open(self.log_dir + f'reward_mean_{self.episode_start}_{self.n_episodes - 1}.txt', 'a', encoding='utf-8') as file: file.write(f'{time.time() - self.t_start:14.03f} {self.n_calls:12d} {episode_curr:12d} {reward_curr:14.06f} {self.episode_best:12d} {self.reward_best:14.06f}\n') file.close() gc.collect() - return True \ No newline at end of file + return True diff --git a/requirements.txt b/requirements.txt index fff2cec..1358fcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -gymnasium -matplotlib 
numpy<2.0.0 scipy -seaborn +matplotlib +tqdm +pillow +pandas +gymnasium stable-baselines3 -tqdm \ No newline at end of file diff --git a/requirements_tox.txt b/requirements_tox.txt new file mode 100644 index 0000000..27cddc8 --- /dev/null +++ b/requirements_tox.txt @@ -0,0 +1,14 @@ +pytest +pytest-cov +pylint +numpy<2.0.0 +scipy +matplotlib +tqdm +pillow +pandas +gymnasium +stable-baselines3 +torchdiffeq +jax +diffrax diff --git a/setup.py b/setup.py deleted file mode 100644 index 28094c2..0000000 --- a/setup.py +++ /dev/null @@ -1,38 +0,0 @@ -from setuptools import setup, find_packages - -with open('README.md', 'r') as file_readme: - long_desc = file_readme.read() - -setup( - name='quantrl', - version='0.0.7', - author='Sampreet Kalita', - author_email='sampreet.kalita@hotmail.com', - desctiption='Quantum Control with Reinforcement Learning', - long_description=long_desc, - long_description_content_type='text/markdown', - keywords=['quantum', 'toolbox', 'reinforcement learning', 'python3'], - url='https://github.com/sampreet/quantrl', - packages=find_packages(), - classifiers=[ - 'Programming Language :: Python :: 3', - 'Development Status :: 2 - Pre-Alpha', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Topic :: Scientific/Engineering' - ], - license='BSD', - install_requires=[ - 'gymnasium', - 'matplotlib', - 'numpy<2.0.0', - 'scipy', - 'seaborn' - 'stable-baselines3', - 'tqdm' - ], - python_requires='>=3.8', - zip_safe=False, - include_package_data=True -)