-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #411 from BradyAJohnston/dev-fetch
Add standalone structure downloader
- Loading branch information
Showing
2 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import requests | ||
import io | ||
|
||
def fetch(code, format="cif", cache=None, database='rcsb'): | ||
""" | ||
Downloads a structure from the specified protein data bank in the given format. | ||
Parameters | ||
---------- | ||
code : str | ||
The code of the file to fetch. | ||
format : str, optional | ||
The format of the file. Defaults to "cif". Possible values are ['cif', 'pdb', | ||
'mmcif', 'pdbx', 'mmtf', 'bcif']. | ||
cache : str, optional | ||
The cache directory to store the fetched file. Defaults to None. | ||
database : str, optional | ||
The database to fetch the file from. Defaults to 'rcsb'. | ||
Returns | ||
------- | ||
file | ||
The fetched file as a file-like object. | ||
Raises | ||
------ | ||
ValueError | ||
If the specified format is not supported. | ||
""" | ||
supported_formats = ['cif', 'pdb', 'mmtf', 'bcif'] | ||
if format not in supported_formats: | ||
raise ValueError(f"File format '{format}' not in: {supported_formats=}") | ||
|
||
_is_binary = (format in ['bcif', 'mmtf']) | ||
filename = f"{code}.{format}" | ||
# create the cache location | ||
if cache: | ||
if not os.path.isdir(cache): | ||
os.makedirs(cache) | ||
|
||
file = os.path.join(cache, filename) | ||
else: | ||
file = None | ||
|
||
# get the contents of the url | ||
r = requests.get(_url(code, format, database)) | ||
if _is_binary: | ||
content = r.content | ||
else: | ||
content = r.text | ||
|
||
if file: | ||
mode = "wb+" if _is_binary else "w+" | ||
with open(file, mode) as f: | ||
f.write(content) | ||
else: | ||
if _is_binary: | ||
file = io.BytesIO(content) | ||
else: | ||
file = io.StringIO(content) | ||
|
||
return file | ||
|
||
def _url(code, format, database="rcsb"): | ||
"Get the URL for downloading the given file form a particular database." | ||
|
||
if database == "rcsb": | ||
if format == "bcif": | ||
return f"https://models.rcsb.org/{code}.bcif" | ||
if format == "mmtf": | ||
return f"https://mmtf.rcsb.org/v1.0/full/{code}" | ||
|
||
else: | ||
return f"https://files.rcsb.org/download/{code}.{format}" | ||
# if database == "pdbe": | ||
# return f"https://www.ebi.ac.uk/pdbe/entry-files/download/{filename}" | ||
else: | ||
ValueError(f"Database {database} not currently supported.") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import os | ||
import io | ||
import pytest | ||
from molecularnodes.io import download | ||
import biotite.database.rcsb as rcsb | ||
from biotite.structure.io import load_structure | ||
import tempfile | ||
|
||
from .constants import codes | ||
|
||
databases = ['rcsb'] # currently can't figure out downloading from other services | ||
|
||
def _filestart(format): | ||
if format == "cif": | ||
return 'data_' | ||
else: | ||
return 'HEADER' | ||
|
||
|
||
|
||
@pytest.mark.parametrize('format', ['cif', 'mmtf', 'pdb']) | ||
def test_compare_biotite(format): | ||
struc_download = load_structure(download.fetch('4ozs', format=format, cache=tempfile.TemporaryDirectory().name)) | ||
struc_biotite = load_structure(rcsb.fetch('4ozs', format=format, target_path=tempfile.TemporaryDirectory().name)) | ||
assert struc_download == struc_biotite | ||
|
||
@pytest.mark.parametrize('code', codes) | ||
@pytest.mark.parametrize('database', databases) | ||
@pytest.mark.parametrize('format', ['pdb', 'cif']) | ||
def test_fetch_with_cache(tmpdir, code, format, database): | ||
cache_dir = tmpdir.mkdir("cache") | ||
file = download.fetch(code, format, cache=str(cache_dir), database=database) | ||
|
||
assert isinstance(file, str) | ||
assert os.path.isfile(file) | ||
assert file.endswith(f"{code}.{format}") | ||
|
||
|
||
with open(file, "r") as f: | ||
content = f.read() | ||
assert content.startswith(_filestart(format)) | ||
|
||
databases = ['rcsb'] # currently can't figure out downloading from the pdbe | ||
|
||
@pytest.mark.parametrize('code', codes) | ||
@pytest.mark.parametrize('database', databases) | ||
@pytest.mark.parametrize('format', ['pdb', 'cif']) | ||
def test_fetch_without_cache(tmpdir, code, format, database): | ||
file = download.fetch(code, format, cache=None, database=database) | ||
|
||
assert isinstance(file, io.StringIO) | ||
content = file.getvalue() | ||
assert content.startswith(_filestart(format)) | ||
|
||
@pytest.mark.parametrize('database', databases) | ||
def test_fetch_with_invalid_format(database): | ||
code = '4OZS' | ||
format = "xyz" | ||
|
||
with pytest.raises(ValueError): | ||
download.fetch(code, format, cache=None, database=database) | ||
|
||
@pytest.mark.parametrize('code', codes) | ||
@pytest.mark.parametrize('database', databases) | ||
@pytest.mark.parametrize('format', ['bcif', 'mmtf']) | ||
def test_fetch_with_binary_format(tmpdir, code, database, format): | ||
cache_dir = tmpdir.mkdir("cache") | ||
file = download.fetch(code, format, cache=str(cache_dir), database=database) | ||
|
||
assert isinstance(file, str) | ||
assert os.path.isfile(file) | ||
assert file.endswith(f"{code}.{format}") | ||
|
||
if format == "bcif": | ||
start = b"\x83\xa7" | ||
elif format == "mmtf": | ||
start = b"\xde\x00" | ||
|
||
with open(file, "rb") as f: | ||
content = f.read() | ||
assert content.startswith(start) |