Add tox and CI
Sampreet committed Oct 14, 2024
1 parent 1dde917 commit a9f791d
Showing 28 changed files with 643 additions and 393 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/python-tox.yml
@@ -0,0 +1,52 @@
name: tox

on: [
  push,
  pull_request,
]

jobs:
  build:

    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [
          ubuntu-latest,
          macos-12,
          windows-latest,
        ]

    steps:
      - uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"
          cache: "pip"
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Install package
        run: pip install -e .

  tox:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python: [
          "3.10",
          "3.12",
        ]

    steps:
      - uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          cache: "pip"
      - name: Install tox
        run: pip install tox
      - name: Run tox
        run: tox
13 changes: 12 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,17 @@
# Changelog

## 2027/08/09 - 00 - v0.0.7 - Update Callback
## 2024/10/14 - 00 - v0.0.8 - Instantiation and GitHub CI
* Instantiated backends and solvers with different numerical libraries:
    * Added `context_manager` modules to `quantrl.backends` and `quantrl.solvers` packages.
    * Added support for JIT-compilation in `quantrl.backends` package.
    * Changed the naming convention for GPU-based options from `cuda` to `gpu`.
* Added `tox`-based testing and continuous integration with GitHub workflows.
* Updated all modules with PEP-based code styling and added `pylintrc`.
* Updated `requirements.txt` and added version to `quantrl.__init__`.
* Replaced `setup.py` with `pyproject.toml`.
* Updated `docs` and `README`.

## 2024/08/09 - 00 - v0.0.7 - Update Callback
* Fixed issue with cache indexing in `quantrl.envs.base` module.
* Updated best reward callback in `quantrl.utils` module.
* Added `seaborn` to `requirements` and `setup`.
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
# QuantRL: Quantum Control using Reinforcement Learning

![Latest Version](https://img.shields.io/badge/version-0.0.7-red?style=for-the-badge)
![Latest Version](https://img.shields.io/badge/version-0.0.8-red?style=for-the-badge)

> A library of modules to interface deterministic and stochastic quantum models for reinforcement learning.
9 changes: 9 additions & 0 deletions docs/source/quantrl.backends.rst
@@ -10,6 +10,15 @@ quantrl.backends.base module
    :undoc-members:
    :show-inheritance:

quantrl.backends.context_manager module
---------------------------------------

.. automodule:: quantrl.backends.context_manager
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:

quantrl.backends.jax module
---------------------------

9 changes: 9 additions & 0 deletions docs/source/quantrl.solvers.rst
@@ -10,6 +10,15 @@ quantrl.solvers.base module
    :undoc-members:
    :show-inheritance:

quantrl.solvers.context_manager module
---------------------------------------

.. automodule:: quantrl.solvers.context_manager
    :members:
    :private-members:
    :undoc-members:
    :show-inheritance:

quantrl.solvers.jax module
--------------------------

26 changes: 26 additions & 0 deletions pylintrc
@@ -0,0 +1,26 @@
[MESSAGES CONTROL]
disable=too-many-lines,
    too-many-locals,
    too-many-arguments,
    too-many-statements,
    too-few-public-methods,
    too-many-public-methods,
    too-many-instance-attributes,
    too-many-positional-arguments,
    line-too-long, # no bounded lines
    redefined-builtin, # catches __name__
    duplicate-code, # catches __init__
    unused-argument, # catches Gym.Env methods
    import-outside-toplevel, # for numerical libraries
    not-callable, # catches JAX JIT functions
    fixme

[BASIC]
attr-naming-style=any
variable-naming-style=any
argument-naming-style=any
const-naming-style=UPPER_CASE
method-naming-style=any
function-naming-style=any
class-naming-style=PascalCase
class-attribute-naming-style=snake_case
87 changes: 87 additions & 0 deletions pyproject.toml
@@ -0,0 +1,87 @@
[build-system]
requires = ["cython", "setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
dynamic = ["version"]
name = "quantrl"
authors = [
{name = "Sampreet Kalita", email = "[email protected]"},
]
maintainers = [
{name = "Sampreet Kalita", email = "[email protected]"},
]
description = "Quantum Control with Reinforcement Learning"
keywords = [
"quantum",
"toolbox",
"reinforcement learning",
"python3",
]
readme = "README.md"
license = {file = "LICENSE.txt"}
requires-python = ">=3.10"
dependencies = [
"numpy<2.0.0",
"scipy",
"matplotlib",
"tqdm",
"pillow",
"pandas",
"gymnasium",
"stable-baselines3",
]
classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering",
]

[project.optional-dependencies]
jax-cpu = [
"jax",
"diffrax",
]
jax-gpu = [
"jax[cuda12]",
"diffrax",
]
torch = [
"torch",
"torchvision",
"torchaudio",
"torchdiffeq",
]

[project.urls]
Homepage = "https://github.com/sampreet/quantrl"
Repository = "https://github.com/sampreet/quantrl"
Issues = "https://github.com/sampreet/quantrl/issues"
Changelog = "https://github.com/sampreet/quantrl/blob/master/CHANGELOG.md"

[tool.setuptools.packages.find]
include = ["quantrl"]
namespaces = false

[tool.setuptools.dynamic]
version = {attr = "quantrl.__version__"}

[tool.tox]
legacy_tox_ini = """
[tox]
requires =
    tox>=4
    virtualenv>=20
env_list =
    lint
[testenv:lint]
description = run pylint under {base_python}
deps =
    -r requirements_tox.txt
commands =
    pylint quantrl
"""
2 changes: 2 additions & 0 deletions quantrl/__init__.py
@@ -0,0 +1,2 @@
"""Module to initialize QuantRL."""
__version__ = "0.0.8"
65 changes: 46 additions & 19 deletions quantrl/backends/base.py
@@ -6,10 +6,11 @@
__name__ = 'quantrl.backends.base'
__authors__ = ["Sampreet Kalita"]
__created__ = "2024-03-10"
__updated__ = "2024-10-09"
__updated__ = "2024-10-13"

# dependencies
from abc import ABC, abstractmethod

import numpy as np

class BaseBackend(ABC):
@@ -29,9 +30,9 @@ class BaseBackend(ABC):

    def __init__(
        self,
        name,
        library,
        tensor_type,
        name:str='numpy',
        library:np=np,
        tensor_type:np.ndarray=np.ndarray,
        precision:str='double'
    ):
        # validate parameters
@@ -68,6 +69,7 @@ def __init__(
                }
            }
        }
        self.seed_sequence = None

    def is_typed(self,
        tensor,
@@ -91,11 +93,32 @@ def is_typed(self,
        _dtype = self.dtype_from_str(
            dtype=dtype
        )
        if type(tensor) == self.tensor_type:
        if isinstance(tensor, self.tensor_type):
            if dtype is None or (dtype is not None and tensor.dtype == _dtype):
                return True
        return False
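
For context, the switch above from a direct type() comparison to isinstance() also accepts subclasses of the backend's tensor type. A minimal standalone illustration with NumPy (not part of the diff; TrackedArray is a hypothetical subclass used only for demonstration):

import numpy as np

class TrackedArray(np.ndarray):
    """Hypothetical ndarray subclass for illustration."""

tensor = np.zeros(4).view(TrackedArray)
print(type(tensor) == np.ndarray)      # False: exact type comparison rejects the subclass
print(isinstance(tensor, np.ndarray))  # True: isinstance() accepts it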

    def get_seedsequence(self,
        seed:int=None
    ) -> np.random.SeedSequence:
        """Method to obtain a SeedSequence object.
        Parameters
        ----------
        seed: int
            Initial seed to obtain the entropy.
        Returns
        -------
        seed_sequence: :class:`numpy.random.SeedSequence`
            The SeedSequence object.
        """
        if seed is None:
            entropy = np.random.randint(1234567890)
        else:
            entropy = np.random.default_rng(seed).integers(0, 1234567890, (1, ))[0]
        return np.random.SeedSequence(entropy)
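
As a usage note (standard NumPy behaviour, not specific to this backend), the SeedSequence returned above can be spawned into independent child generators so that parallel components draw reproducible, non-overlapping random streams:

import numpy as np

seed_sequence = np.random.SeedSequence(12345)
children = seed_sequence.spawn(3)                   # independent child sequences
rngs = [np.random.default_rng(child) for child in children]
samples = [rng.standard_normal(2) for rng in rngs]  # reproducible for the same entropy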

    @abstractmethod
    def convert_to_typed(self,
        tensor,
@@ -546,37 +569,41 @@ def dtype_from_str(self,
        dtype: type
            Selected data-type.
        """

        # validate params
        assert dtype is None or dtype in ['integer', 'real', 'complex'], "parameter ``dtype`` can be either ``'integer'``, ``'real'`` or ``'complex'``."

        # default dtype is the real data-type
        if dtype is None:
        if dtype is None or dtype not in ['integer', 'real', 'complex']:
            dtype = 'real'
        return self.dtypes['numpy' if numpy else 'typed'][self.precision][dtype]
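
For reference, the lookup in the return statement above walks a nested precision/dtype table; the sketch below is an assumed shape of its NumPy branch (the actual table in quantrl may differ, and the 'typed' branch would hold the selected library's own dtypes):

import numpy as np

# assumed layout: precision -> dtype name -> NumPy dtype
DTYPES_NUMPY = {
    'single': {'integer': np.int32, 'real': np.float32, 'complex': np.complex64},
    'double': {'integer': np.int64, 'real': np.float64, 'complex': np.complex128},
}
assert DTYPES_NUMPY['double']['real'] is np.float64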

    def jit_transpose(self, tensor, axis_0, axis_1):
        """Method to JIT-compile transposition."""
        return self.transpose(tensor, axis_0, axis_1)

    def jit_repeat(self, tensor, repeats, axis):
        """Method to JIT-compile repetition."""
        return self.repeat(tensor, repeats, axis)

    def jit_add(self, tensor_0, tensor_1, out):
        """Method to JIT-compile addition."""
        return self.add(tensor_0, tensor_1, out=out)

    def jit_matmul(self, tensor_0, tensor_1, out):
        """Method to JIT-compile matrix multiplication."""
        return self.matmul(tensor_0, tensor_1, out)

    def jit_dot(self, tensor_0, tensor_1, out):
        """Method to JIT-compile dot product."""
        return self.dot(tensor_0, tensor_1, out)

    def jit_concatenate(self, tensors, axis, out):
        """Method to JIT-compile concatenation."""
        return self.concatenate(tensors, axis, out)

    def jit_stack(self, tensors, axis, out):
        """Method to JIT-compile stacking."""
        return self.stack(tensors, axis, out)

    def jit_update(self, tensor, indices, values):
        """Method to JIT-compile updation."""
        return self.update(tensor, indices, values)
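
The jit_* wrappers above simply delegate to their non-compiled counterparts; a backend built on JAX would typically override them with JIT-compiled versions. A minimal sketch of the idea (hypothetical, not quantrl's actual JaxBackend code):

import jax
import jax.numpy as jnp

# axes must be static under jit, so they are marked as static arguments
jit_transpose = jax.jit(
    lambda tensor, axis_0, axis_1: jnp.swapaxes(tensor, axis_0, axis_1),
    static_argnums=(1, 2),
)

x = jnp.arange(6.0).reshape(2, 3)
y = jit_transpose(x, 0, 1)  # compiled on first call; result has shape (3, 2)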

    def empty(self,
@@ -1003,4 +1030,4 @@ def argmax(self,

        return self.library.argmax(self.convert_to_typed(
            tensor=tensor
        ))
        ))