Skip to content

Commit

Permalink
Add multi-entity associations CLI cmd and api endpoint (#369)
Browse files Browse the repository at this point in the history
### Related issues

- Closes #365 

### Summary

- adds CLI command for multi-entity associations, for example:  
```bash
$ monarch multi-entity-associations -e MONDO:0012933 -e MONDO:0005439 -e MANGO:0023456 -c biolink:Gene -c biolink:Disease 
```
- adds API endpoint for same thing, for example:  
```
<MONARCH_API_URL>/v3/api/association/multi?entity=MONDO%3A0012933&entity=MONDO%3A0005439&entity=MANDO%3A0001138&counterpart_category=biolink%3AGene&counterpart_category=biolink%3ADisease&limit=20&offset=0
```
- adds solr integration test for multi-entity associations
implementation
  • Loading branch information
glass-ships authored Oct 6, 2023
1 parent f75c376 commit ec8c6c3
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 14 deletions.
17 changes: 16 additions & 1 deletion backend/src/monarch_py/api/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, Depends, Query
from monarch_py.api.additional_models import PaginationParams
from monarch_py.api.config import solr
from monarch_py.datamodels.model import AssociationResults
from monarch_py.datamodels.model import AssociationResults, MultiEntityAssociationResults

router = APIRouter(
tags=["association"],
Expand Down Expand Up @@ -34,3 +34,18 @@ async def _get_associations(
limit=pagination.limit,
)
return response

@router.get("/multi")
async def _get_multi_entity_associations(
entity: Union[List[str], None] = Query(default=None),
counterpart_category: Union[List[str], None] = Query(default=None),
pagination: PaginationParams = Depends(),
) -> List[MultiEntityAssociationResults]:
"""Retrieves all associations between each entity and each counterpart category."""
response = solr().get_multi_entity_associations(
entity=entity,
counterpart_category=counterpart_category,
offset=pagination.offset,
limit_per_group=pagination.limit,
)
return response
28 changes: 28 additions & 0 deletions backend/src/monarch_py/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,5 +279,33 @@ def compare(
format_output(fmt, response, output)


@app.command("multi-entity-associations")
def multi_entity_associations(
entity: List[str] = typer.Option(None, "--entity", "-e", help="Comma-separated list of entities"),
counterpart_category: List[str] = typer.Option(None, "--counterpart-category", "-c"),
limit: int = typer.Option(20, "--limit", "-l"),
offset: int = typer.Option(0, "--offset"),
fmt: str = typer.Option(
"json",
"--format",
"-f",
help="The format of the output (json, yaml, tsv, table)",
),
output: str = typer.Option(None, "--output", "-o", help="The path to the output file"),
):
"""
Paginate through associations for multiple entities
Args:
entity: A comma-separated list of entities
counterpart_category: A comma-separated list of counterpart categories
limit: The number of associations to return
offset: The offset of the first association to be retrieved
fmt: The format of the output (json, yaml, tsv, table)
output: The path to the output file (stdout if not specified)
"""
solr_cli.multi_entity_associations(**locals())


if __name__ == "__main__":
app()
4 changes: 3 additions & 1 deletion backend/src/monarch_py/datamodels/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ class MultiEntityAssociationResults(Results):

id: str = Field(...)
name: Optional[str] = Field(None)
associated_categories: List[CategoryGroupedAssociationResults] = Field(default_factory=list)
associated_categories: List[CategoryGroupedAssociationResults] = Field(
default_factory=list
)
limit: int = Field(..., description="""number of items to return in a response""")
offset: int = Field(..., description="""offset into the total number of items""")
total: int = Field(..., description="""total number of items matching a query""")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,11 @@ def get_multi_entity_associations(
for entity_id in entity:
ent = self.get_entity(entity_id, extra=False)
if ent is None:
results.append(MultiEntityAssociationResults(
id=entity_id, name="Entity not found", total=0, offset=offset, limit=limit_per_group, associated_categories=[]
))
results.append(
MultiEntityAssociationResults(
id=entity_id, name="Entity not found", total=0, offset=0, limit=0, associated_categories=[]
)
)
continue
entity_result = MultiEntityAssociationResults(
id=ent.id, name=ent.name, total=0, offset=offset, limit=limit_per_group, associated_categories=[]
Expand Down
40 changes: 37 additions & 3 deletions backend/src/monarch_py/solr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def associations(
output: The path to the output file (stdout if not specified)
"""
args = locals()
args.pop("update", None)
args.pop("fmt", None)
args.pop("output", None)

Expand All @@ -169,6 +168,43 @@ def associations(
format_output(fmt, response, output)


@solr_app.command("multi-entity-associations")
def multi_entity_associations(
entity: List[str] = typer.Option(None, "--entity", "-e", help="Entity ID to get associations for"),
counterpart_category: List[str] = typer.Option(None, "--counterpart-category", "-c", help="Counterpart category to get associations for"),
limit: int = typer.Option(20, "--limit", "-l"),
offset: int = typer.Option(0, "--offset"),
fmt: str = typer.Option(
"json",
"--format",
"-f",
help="The format of the output (json, yaml, tsv, table)",
),
output: str = typer.Option(None, "--output", "-o", help="The path to the output file"),
):
"""
Paginate through associations for multiple entities
Args:
entity: A comma-separated list of entities
counterpart_category: A comma-separated list of counterpart categories
limit: The number of associations to return
offset: The offset of the first association to be retrieved
fmt: The format of the output (json, yaml, tsv, table)
output: The path to the output file (stdout if not specified)
"""
# console.print("\n[bold red]Multi-entity associations not implemented in CLI.[/]\n")
# raise typer.Exit(1)
args = locals()
args.pop("fmt", None)
args.pop("output", None)
args['limit_per_group'] = args.pop('limit')

solr = get_solr(update=False)
response = solr.get_multi_entity_associations(**args)
format_output(fmt, response, output)


@solr_app.command("search")
def search(
q: str = typer.Option(None, "--query", "-q"),
Expand Down Expand Up @@ -253,7 +289,6 @@ def histopheno(
subject (str): The subject of the association
Optional Args:
update (bool): Whether to re-download the Monarch KG. Default False
fmt (str): The format of the output (json, yaml, tsv, table). Default JSON
output (str): The path to the output file. Default stdout
"""
Expand Down Expand Up @@ -285,7 +320,6 @@ def association_counts(
entity (str): The entity to get association counts for
Optional Args:
update (bool): Whether to re-download the Monarch KG. Default False
fmt (str): The format of the output (json, yaml, tsv, table). Default JSON
output (str): The path to the output file. Default stdout
"""
Expand Down
16 changes: 10 additions & 6 deletions backend/src/monarch_py/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
import json
import sys
from typing import Union, Dict
from typing import Union, Dict, List

import typer
import yaml
Expand Down Expand Up @@ -88,12 +88,14 @@ def get_headers_from_obj(obj: ConfiguredBaseModel) -> list:
return list(headers)


def to_json(obj: Union[ConfiguredBaseModel, Dict], file: str):
def to_json(obj: Union[ConfiguredBaseModel, Dict, List[ConfiguredBaseModel]], file: str):
"""Converts a pydantic model to a JSON string."""
if isinstance(obj, ConfiguredBaseModel):
json_value = obj.json(indent=4)
elif isinstance(obj, dict):
json_value = json.dumps(obj, indent=4)
elif isinstance(obj, list):
json_value = json.dumps({"items": [o.dict() for o in obj]}, indent=4)
if file:
with open(file, "w") as f:
f.write(json_value)
Expand All @@ -117,7 +119,8 @@ def to_tsv(obj: ConfiguredBaseModel, file: str) -> str:
headers = obj.items[0].dict().keys()
rows = [list(item.dict().values()) for item in obj.items]
else:
raise TypeError(FMT_INPUT_ERROR_MSG)
console.print(f"\n[bold red]{FMT_INPUT_ERROR_MSG}[/]\n")
raise typer.Exit(1)

fh = open(file, "w") if file else sys.stdout
writer = csv.writer(fh, delimiter="\t")
Expand Down Expand Up @@ -147,7 +150,8 @@ def to_table(obj: ConfiguredBaseModel):
headers = obj.items[0].dict().keys()
rows = [list(item.dict().values()) for item in obj.items]
else:
raise TypeError(FMT_INPUT_ERROR_MSG)
console.print(f"\n[bold red]{FMT_INPUT_ERROR_MSG}[/]\n")
raise typer.Exit(1)

for row in rows:
for i, value in enumerate(row):
Expand Down Expand Up @@ -179,8 +183,8 @@ def to_yaml(obj: ConfiguredBaseModel, file: str):
elif isinstance(obj, Results) or isinstance(obj, HistoPheno) or isinstance(obj, AssociationCountList):
yaml.dump([item.dict() for item in obj.items], fh, indent=4)
else:
raise TypeError(FMT_INPUT_ERROR_MSG)

console.print(f"\n[bold red]{FMT_INPUT_ERROR_MSG}[/]\n")
raise typer.Exit(1)
if file:
console.print(f"\nOutput written to {file}\n")
fh.close()
Expand Down
15 changes: 15 additions & 0 deletions backend/tests/integration/test_solr_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,18 @@ def test_entity():
assert response.total > 50
for association in response.items:
assert "MONDO:0007947" in association.subject_closure or "MONDO:0007947" in association.object_closure


def test_multi_entity_associations():
si = SolrImplementation()
response = si.get_multi_entity_associations(
entity=["MONDO:0012933", "MONDO:0005439", "MANDO:0001138"],
counterpart_category=["biolink:Gene", "biolink:Disease"],
)
assert response
assert len(response) == 3
assert response[2].name == "Entity not found"
# assert response[0].associated_categories['biolink:Disease'].total > 0
for c in response[0].associated_categories:
if c.counterpart_category == "biolink:Disease":
assert c.total > 0

0 comments on commit ec8c6c3

Please sign in to comment.