diff --git a/molecularnodes/io/download.py b/molecularnodes/io/download.py
new file mode 100644
index 00000000..069fdc1c
--- /dev/null
+++ b/molecularnodes/io/download.py
@@ -0,0 +1,94 @@
+import os
+import requests
+import io
+
+def fetch(code, format="cif", cache=None, database='rcsb'):
+    """
+    Download a structure from the specified protein data bank in the given format.
+
+    Parameters
+    ----------
+    code : str
+        The code of the file to fetch.
+    format : str, optional
+        The format of the file. Defaults to "cif". Possible values are
+        ['cif', 'pdb', 'mmtf', 'bcif'].
+    cache : str, optional
+        The cache directory to store the fetched file in; it is created if it
+        does not exist. Defaults to None, which keeps the download in memory.
+    database : str, optional
+        The database to fetch the file from. Defaults to 'rcsb', currently the
+        only supported database.
+
+    Returns
+    -------
+    str or io.StringIO or io.BytesIO
+        The path of the written file when `cache` is given. Otherwise an
+        in-memory file-like object holding the download: `io.BytesIO` for the
+        binary formats ('bcif', 'mmtf') and `io.StringIO` for the text formats.
+
+    Raises
+    ------
+    ValueError
+        If the specified format or database is not supported.
+    requests.HTTPError
+        If the server answers with an error status (e.g. an unknown code).
+    """
+    supported_formats = ['cif', 'pdb', 'mmtf', 'bcif']
+    if format not in supported_formats:
+        raise ValueError(f"File format '{format}' not in: {supported_formats=}")
+
+    # resolve the URL first so an unsupported database fails before any
+    # directory is created or any request is sent
+    url = _url(code, format, database)
+
+    _is_binary = (format in ['bcif', 'mmtf'])
+    filename = f"{code}.{format}"
+    # create the cache location
+    if cache:
+        os.makedirs(cache, exist_ok=True)
+        file = os.path.join(cache, filename)
+    else:
+        file = None
+
+    # get the contents of the url
+    r = requests.get(url)
+    # fail loudly instead of storing an HTTP error page as a structure file
+    r.raise_for_status()
+    if _is_binary:
+        content = r.content
+    else:
+        content = r.text
+
+    if file:
+        mode = "wb+" if _is_binary else "w+"
+        with open(file, mode) as f:
+            f.write(content)
+    else:
+        if _is_binary:
+            file = io.BytesIO(content)
+        else:
+            file = io.StringIO(content)
+
+    return file
+
+def _url(code, format, database="rcsb"):
+    """
+    Get the URL for downloading the given file from a particular database.
+
+    Raises a ValueError if the database is not supported.
+    """
+
+    if database == "rcsb":
+        if format == "bcif":
+            return f"https://models.rcsb.org/{code}.bcif"
+        if format == "mmtf":
+            return f"https://mmtf.rcsb.org/v1.0/full/{code}"
+
+        else:
+            return f"https://files.rcsb.org/download/{code}.{format}"
+    # if database == "pdbe":
+    #     return f"https://www.ebi.ac.uk/pdbe/entry-files/download/{filename}"
+    else:
+        raise ValueError(f"Database {database} not currently supported.")
+
diff --git a/tests/test_download.py b/tests/test_download.py
new file mode 100644
index 00000000..c43f16b8
--- /dev/null
+++ b/tests/test_download.py
@@ -0,0 +1,85 @@
+import os
+import io
+import pytest
+from molecularnodes.io import download
+import biotite.database.rcsb as rcsb
+from biotite.structure.io import load_structure
+import tempfile
+
+from .constants import codes
+
+databases = ['rcsb'] # currently can't figure out downloading from other services (e.g. the pdbe)
+
+def _filestart(format):
+    "Return the text that a downloaded file of the given text format starts with."
+    if format == "cif":
+        return 'data_'
+    else:
+        return 'HEADER'
+
+
+@pytest.mark.parametrize('format', ['cif', 'mmtf', 'pdb'])
+def test_compare_biotite(format):
+    # hold on to the directory objects so both directories exist while loading and are removed afterwards
+    with tempfile.TemporaryDirectory() as dir_download, tempfile.TemporaryDirectory() as dir_biotite:
+        struc_download = load_structure(download.fetch('4ozs', format=format, cache=dir_download))
+        struc_biotite = load_structure(rcsb.fetch('4ozs', format=format, target_path=dir_biotite))
+    assert struc_download == struc_biotite
+
+@pytest.mark.parametrize('code', codes)
+@pytest.mark.parametrize('database', databases)
+@pytest.mark.parametrize('format', ['pdb', 'cif'])
+def test_fetch_with_cache(tmpdir, code, format, database):
+    cache_dir = tmpdir.mkdir("cache")
+    file = download.fetch(code, format, cache=str(cache_dir), database=database)
+
+    assert isinstance(file, str)
+    assert os.path.isfile(file)
+    assert file.endswith(f"{code}.{format}")
+
+    with open(file, "r") as f:
+        content = f.read()
+    assert content.startswith(_filestart(format))
+
+@pytest.mark.parametrize('code', codes)
+@pytest.mark.parametrize('database', databases)
+@pytest.mark.parametrize('format', ['pdb', 'cif'])
+def test_fetch_without_cache(code, format, database):
+    file = download.fetch(code, format, cache=None, database=database)
+
+    assert isinstance(file, io.StringIO)
+    content = file.getvalue()
+    assert content.startswith(_filestart(format))
+
+@pytest.mark.parametrize('database', databases)
+def test_fetch_with_invalid_format(database):
+    code = '4OZS'
+    format = "xyz"
+
+    with pytest.raises(ValueError):
+        download.fetch(code, format, cache=None, database=database)
+
+def test_fetch_with_invalid_database():
+    # matched on the message: the invalid-URL error raised by requests is a ValueError as well
+    with pytest.raises(ValueError, match="not currently supported"):
+        download.fetch('4OZS', 'cif', cache=None, database='not-a-database')
+
+@pytest.mark.parametrize('code', codes)
+@pytest.mark.parametrize('database', databases)
+@pytest.mark.parametrize('format', ['bcif', 'mmtf'])
+def test_fetch_with_binary_format(tmpdir, code, database, format):
+    cache_dir = tmpdir.mkdir("cache")
+    file = download.fetch(code, format, cache=str(cache_dir), database=database)
+
+    assert isinstance(file, str)
+    assert os.path.isfile(file)
+    assert file.endswith(f"{code}.{format}")
+
+    if format == "bcif":
+        start = b"\x83\xa7"
+    elif format == "mmtf":
+        start = b"\xde\x00"
+
+    with open(file, "rb") as f:
+        content = f.read()
+    assert content.startswith(start)