diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a98b4d48..83b08dd0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -34,7 +34,7 @@ jobs: - name: Move install script run: mv install.sh docs/_build/html - name: Deploy documentation - if: ${{ github.event_name == 'push' && github.ref_name == 'main' }} + if: ${{ github.event_name == 'push' && (github.ref_name == 'main' || github.ref_name == 'release') }} uses: s0/git-publish-subdir-action@develop env: REPO: self diff --git a/README.md b/README.md index ffd2ed09..c96b8ce6 100644 --- a/README.md +++ b/README.md @@ -92,20 +92,6 @@ make html ``` the docs will be under `docs/_build`. -## Known issues - -When using lambeq on a Windows machine, the instantiation of the BobcatParser -might trigger an SSL certificate error. We are currently investigating the -issue. In the meantime, you can download the model through this -[link](https://qnlp.cambridgequantum.com/models/bert/latest/model.tar.gz), -extract the archive, and provide the path to the BobcatParser: - -```python -from lambeq import BobcatParser - -parser = BobcatParser('path/to/model_dir') -``` - ## License Distributed under the Apache 2.0 license. See [`LICENSE`](LICENSE) for diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 8e414328..33cad604 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -3,6 +3,17 @@ Release notes ============= +.. _rel-0.2.4: + +`0.2.4 `_ +------------------------------------------------------------ + +- Fix a bug that caused the :py:class:`~lambeq.BobcatParser` and the :py:class:`~lambeq.WebParser` to trigger an SSL certificate error using Windows. + +- Fix false positives in assigning conjunction rule using the :py:class:`~lambeq.CCGBankParser`. The rule ``, + X[conj] -> X[conj]`` is a case of removing left punctuation, but was being assigned conjunction erroneously. + +- Add support for using ``jax`` as backend of ``tensornetwork`` when setting ``use_jit=True`` in the :py:class:`~lambeq.NumpyModel`. The interface is not affected by this change, but performance of the model is significantly improved. + .. _rel-0.2.3: `0.2.3 `_ diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst index bf5a4f73..c9475ef2 100644 --- a/docs/troubleshooting.rst +++ b/docs/troubleshooting.rst @@ -11,10 +11,9 @@ encourage you to SSL error [Windows] ------------------- -When using ``lambeq`` on a Windows machine, the instantiation of the -BobcatParser might trigger an SSL certificate error. We are currently -investigating the issue. In the meantime, you can download the model through -this +When using ``lambeq <= 0.2.3`` on a Windows machine, the instantiation of the +BobcatParser might trigger an SSL certificate error. If you require +``lambeq <= 0.2.3``, you can download the model through this `link `_, extract the archive, and provide the path to the BobcatParser: @@ -22,3 +21,14 @@ extract the archive, and provide the path to the BobcatParser: from lambeq import BobcatParser parser = BobcatParser('path/to/model_dir') + +Note that using the :py:class:`~lambeq.WebParser` will most likely result in +the same error. + +However, this was resolved in release +`0.2.4 `_. Please consider +upgrading lambeq: + +.. code-block:: bash + + pip install --upgrade lambeq diff --git a/lambeq/text2diagram/bobcat_parser.py b/lambeq/text2diagram/bobcat_parser.py index 01dc4e9e..d608cbb9 100644 --- a/lambeq/text2diagram/bobcat_parser.py +++ b/lambeq/text2diagram/bobcat_parser.py @@ -19,11 +19,11 @@ import json import os from pathlib import Path -import shutil +import requests import sys import tarfile +import tempfile from typing import Any, Iterable, Optional, Union -from urllib.request import urlopen, urlretrieve import warnings from discopy.biclosed import Ty @@ -83,8 +83,7 @@ def model_is_stale(model: str, model_dir: str) -> bool: return False try: - with urlopen(url) as f: - remote_version = f.read().strip().decode("utf-8") + remote_version = requests.get(url).text.strip() except Exception: return False @@ -107,41 +106,42 @@ def download_model( if model_dir is None: model_dir = get_model_dir(model_name) - class ProgressBar: - bar = None - - def update(self, chunk: int, chunk_size: int, size: int) -> None: - if self.bar is None: - self.bar = tqdm( - bar_format='Downloading model: {percentage:3.1f}%|' - '{bar}|{n:.3f}/{total:.3f}GB ' - '[{elapsed}<{remaining}]', - total=size/1e9) - warnings.filterwarnings('ignore', category=TqdmWarning) - self.bar.update(chunk_size/1e9) - - def close(self): - self.bar.close() - if verbose == VerbosityLevel.TEXT.value: print('Downloading model...', file=sys.stderr) if verbose == VerbosityLevel.PROGRESS.value: - progress_bar = ProgressBar() - model_file, headers = urlretrieve(url, reporthook=progress_bar.update) - progress_bar.close() + response = requests.get(url, stream=True) + size = int(response.headers.get('content-length', 0)) + block_size = 1024 + + warnings.filterwarnings('ignore', category=TqdmWarning) + progress_bar = tqdm( + bar_format='Downloading model: {percentage:3.1f}%|' + '{bar}|{n:.3f}/{total:.3f}GB ' + '[{elapsed}<{remaining}]', + total=size/1e9) + + model_file = tempfile.NamedTemporaryFile() + for data in response.iter_content(block_size): + progress_bar.update(len(data)/1e9) + model_file.write(data) + else: - model_file, headers = urlretrieve(url) + content = requests.get(url).content + model_file = tempfile.NamedTemporaryFile() + model_file.write(content) # Extract model + model_file.seek(0) if verbose != VerbosityLevel.SUPPRESS.value: print('Extracting model...') - with tarfile.open(model_file) as tar: - tar.extractall(model_dir) + tar = tarfile.open(fileobj=model_file) + tar.extractall(model_dir) + model_file.close() # Download version ver_url = get_model_url(model_name) + '/' + VERSION_FNAME - ver_file, headers = urlretrieve(ver_url) - shutil.copy(ver_file, model_dir / VERSION_FNAME) # type: ignore + with open(os.path.join(model_dir, VERSION_FNAME), 'wb') as w: + w.write(requests.get(ver_url).content) class BobcatParseError(Exception): diff --git a/lambeq/text2diagram/ccgbank_parser.py b/lambeq/text2diagram/ccgbank_parser.py index 0f9e3f6e..3889df4f 100644 --- a/lambeq/text2diagram/ccgbank_parser.py +++ b/lambeq/text2diagram/ccgbank_parser.py @@ -45,7 +45,7 @@ from lambeq.text2diagram.ccg_parser import CCGParser from lambeq.text2diagram.ccg_rule import CCGRule from lambeq.text2diagram.ccg_tree import CCGTree -from lambeq.text2diagram.ccg_types import CONJ_TAG, CCGAtomicType, str2biclosed +from lambeq.text2diagram.ccg_types import CCGAtomicType, str2biclosed class CCGBankParseError(Exception): @@ -408,12 +408,9 @@ def _build_ccgtree(sentence: str, start: int) -> tuple[CCGTree, int]: child, pos = CCGBankParser._build_ccgtree(sentence, pos) children.append(child) - if tree_match['ccg_str'].endswith(CONJ_TAG): - rule = CCGRule.CONJUNCTION - else: - rule = CCGRule.infer_rule( - Ty().tensor(*(child.biclosed_type for child in children)), - biclosed_type) + rule = CCGRule.infer_rule(Ty().tensor(*(child.biclosed_type + for child in children)), + biclosed_type) ccg_tree = CCGTree(rule=rule, biclosed_type=biclosed_type, children=children) diff --git a/lambeq/text2diagram/web_parser.py b/lambeq/text2diagram/web_parser.py index 56139018..078ffa63 100644 --- a/lambeq/text2diagram/web_parser.py +++ b/lambeq/text2diagram/web_parser.py @@ -15,12 +15,9 @@ __all__ = ['WebParser', 'WebParseError'] -import json +import requests import sys from typing import Optional -from urllib.error import HTTPError -from urllib.parse import urlencode -from urllib.request import urlopen from tqdm.auto import tqdm @@ -34,13 +31,14 @@ class WebParseError(OSError): - def __init__(self, sentence: str, error_code: int) -> None: + def __init__(self, sentence: str) -> None: self.sentence = sentence - self.error_code = error_code def __str__(self) -> str: - return (f'Online parsing of sentence {repr(self.sentence)} failed, ' - f'Web status code: {self.error_code}.') + return (f'Web parser could not parse {repr(self.sentence)}.' + 'Check that you are using the correct URL. ' + 'If the URL is correct, this means the parser could not parse ' + 'your sentence.') class WebParser(CCGParser): @@ -136,25 +134,20 @@ def sentences2trees( trees: list[Optional[CCGTree]] = [] if verbose == VerbosityLevel.TEXT.value: print('Parsing sentences.', file=sys.stderr) - for sent in tqdm( + for sentence in tqdm( sentences, desc='Parsing sentences', leave=False, disable=verbose != VerbosityLevel.PROGRESS.value): - params = urlencode({'sentence': sent}) - url = f'{self.service_url}?{params}' + params = {'sentence': sentence} try: - with urlopen(url) as f: - data = json.load(f) - except HTTPError as e: - if suppress_exceptions: - tree = None - else: - raise WebParseError(str(sentence), e.code) - except Exception as e: + data = requests.get(self.service_url, params=params).json() + except requests.RequestException as e: if suppress_exceptions: tree = None + elif type(e) == requests.JSONDecodeError: + raise WebParseError(str(sentence)) else: raise e else: diff --git a/lambeq/training/numpy_model.py b/lambeq/training/numpy_model.py index 7a14283c..f442f018 100644 --- a/lambeq/training/numpy_model.py +++ b/lambeq/training/numpy_model.py @@ -35,6 +35,7 @@ from discopy.tensor import Diagram from sympy import default_sort_key, lambdify +from lambeq.training.model import SizedIterable from lambeq.training.quantum_model import QuantumModel @@ -100,13 +101,40 @@ def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]: return self.lambdas[diagram] def diagram_output(*x): - with Tensor.backend('jax'): - result = diagram.lambdify(*self.symbols)(*x).eval().array + with Tensor.backend('jax'), tn.DefaultBackend('jax'): + sub_circuit = self._fast_subs([diagram], x)[0] + result = tn.contractors.auto(*sub_circuit.to_tn()).tensor return self._normalise_vector(result) self.lambdas[diagram] = jit(diagram_output) return self.lambdas[diagram] + def _fast_subs(self, + diagrams: list[Diagram], + weights: SizedIterable) -> list[Diagram]: + """Substitute weights into a list of parameterised circuit.""" + parameters = {k: v for k, v in zip(self.symbols, weights)} + diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy + for diagram in diagrams: + for b in diagram._boxes: + if b.free_symbols: + while hasattr(b, 'controlled'): + b._free_symbols = set() + b = b.controlled + syms, values = [], [] + for sym in b._free_symbols: + syms.append(sym) + try: + values.append(parameters[sym]) + except KeyError: + raise KeyError(f'Unknown symbol {sym!r}.') + b._data = lambdify(syms, b._data)(*values) + b.drawing_name = b.name + b._free_symbols = set() + if hasattr(b, '_phase'): + b._phase = b._data + return diagrams + def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray: """Return the exact prediction for each diagram. @@ -139,27 +167,7 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray: return numpy.array([diag_f(*self.weights) for diag_f in lambdified_diagrams]) - parameters = {k: v for k, v in zip(self.symbols, self.weights)} - diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy - for diagram in diagrams: - for b in diagram._boxes: - if b.free_symbols: - while hasattr(b, 'controlled'): - b._free_symbols = set() - b = b.controlled - syms, values = [], [] - for sym in b._free_symbols: - syms.append(sym) - try: - values.append(parameters[sym]) - except KeyError: - raise KeyError(f'Unknown symbol {sym!r}.') - b._data = lambdify(syms, b._data)(*values) - b.drawing_name = b.name - b._free_symbols = set() - if hasattr(b, '_phase'): - b._phase = b._data - + diagrams = self._fast_subs(diagrams, self.weights) with Tensor.backend('numpy'): return numpy.array([ self._normalise_vector(tn.contractors.auto(*d.to_tn()).tensor) diff --git a/pyproject.toml b/pyproject.toml index 430e65e9..1e44cebe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ module = [ "depccg.*", "discopy.*", "jax.*", + "requests.*", "sympy.*", "tensornetwork.*", "tqdm.*", diff --git a/tests/text2diagram/test_reader.py b/tests/text2diagram/test_reader.py index 585b051a..736bff4d 100644 --- a/tests/text2diagram/test_reader.py +++ b/tests/text2diagram/test_reader.py @@ -1,4 +1,5 @@ import pytest +from requests.exceptions import MissingSchema from discopy import Word from discopy.rigid import Box, Diagram, Id, Spider @@ -8,6 +9,7 @@ from lambeq import (AtomicType, BobcatParser, IQPAnsatz, TreeReader, TreeReaderMode, VerbosityLevel, WebParser, cups_reader, spiders_reader, stairs_reader) +from lambeq.text2diagram.web_parser import WebParseError @pytest.fixture @@ -108,7 +110,7 @@ def test_suppress_exceptions(sentence): assert bad_reader.sentence2diagram(sentence) is None bad_reader = TreeReader(bad_parser, suppress_exceptions=False) - with pytest.raises(ValueError): + with pytest.raises(MissingSchema): bad_reader.sentence2diagram(sentence) diff --git a/tests/text2diagram/test_web_parser.py b/tests/text2diagram/test_web_parser.py index 197d7d7a..c71b7dab 100644 --- a/tests/text2diagram/test_web_parser.py +++ b/tests/text2diagram/test_web_parser.py @@ -1,4 +1,5 @@ from io import StringIO +from shutil import ExecError import pytest from unittest.mock import patch @@ -56,7 +57,7 @@ def test_bad_url(): assert bad_parser.sentence2diagram( "Need a proper url", suppress_exceptions=True) is None - with pytest.raises(WebParseError): + with pytest.raises(Exception): bad_parser.sentence2diagram("Need a proper url")