Skip to content

Commit

Permalink
optionally init client with project; rm name args
Browse files (browse the repository at this point in the history)
  • Loading branch information
tschaume committed Nov 12, 2021
1 parent 545d7ca commit 17c0aea
Showing 1 changed file with 76 additions and 58 deletions.
134 changes: 76 additions & 58 deletions mpcontribs-client/mpcontribs/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,20 @@ class Client(SwaggerClient):

_shared_state = {}

def __init__(self, apikey: str = None, headers: dict = None, host: str = None):
def __init__(
self,
apikey: str = None,
headers: dict = None,
host: str = None,
project: str = None
):
"""Initialize the client - only reloads API spec from server as needed
Args:
apikey (str): API key (or use MPCONTRIBS_API_KEY env var) - ignored if headers set
headers (dict): custom headers for localhost connections
host (str): host address to connect to (or use MPCONTRIBS_API_HOST env var)
project (str): use this project for all operations (query, update, create, delete)
"""
# - Kong forwards consumer headers when api-key used for auth
# - forward consumer headers when connecting through localhost
Expand Down Expand Up @@ -430,7 +437,10 @@ def __init__(self, apikey: str = None, headers: dict = None, host: str = None):
self.swagger_spec.http_client.headers != self.headers
) or (
self.swagger_spec.spec_dict["host"] != self.host
) or (
"project" not in self.__dict__ or self.project != project
):
self.project = project
self._load()

def __enter__(self):
Expand Down Expand Up @@ -465,11 +475,15 @@ def _load(self):

# expand regex-based query parameters for `data` columns
try:
resp = self.projects.get_entries(_fields=["columns"]).result()
query = {"name": self.project} if self.project else {}
resp = self.projects.get_entries(_fields=["columns"], **query).result()
except AttributeError:
# skip in tests
return

if self.project and not resp["data"]:
raise ValueError(f"{self.project} doesn't exist, or access denied!")

columns = {"string": [], "number": []}

for project in resp["data"]:
Expand Down Expand Up @@ -539,7 +553,7 @@ def _split_query(
) -> List[dict]:
"""Avoid URI too long errors"""
pp_default, pp_max = self._get_per_page_default_max(op=op, resource=resource)
per_page = pp_default if "id__in" in query else pp_max
per_page = pp_default if any(k.endswith("__in") for k in query.keys()) else pp_max
query["per_page"] = per_page
nr_params_to_split = sum(
len(v) > per_page for v in query.values() if isinstance(v, list)
Expand Down Expand Up @@ -603,17 +617,13 @@ def _get_future(
setattr(future, "track_id", track_id)
return future

def get_project_names(self) -> List[str]:
"""Retrieve list of project names."""
resp = self.projects.get_entries(_fields=["name"]).result()
return [p["name"] for p in resp["data"]]

def get_project(self, name: str) -> Type[Dict]:
def get_project(self, name: str = None) -> Type[Dict]:
"""Retrieve full project entry
Args:
name (str): name of the project
"""
name = self.project or name
return Dict(self.projects.get_entry(pk=name, _fields=["_all"]).result())

def get_contribution(self, cid: str) -> Type[Dict]:
Expand Down Expand Up @@ -729,7 +739,7 @@ def get_attachment(self, aid_or_md5: str) -> Type[Attachment]:

return Attachment(self.attachments.get_entry(pk=aid, _fields=["_all"]).result())

def init_columns(self, name: str, columns: dict) -> dict:
def init_columns(self, columns: dict) -> dict:
"""initialize columns for a project to set their order and desired units
The `columns` field tracks the minima and maxima of each `data` field as
Expand All @@ -750,7 +760,7 @@ def init_columns(self, name: str, columns: dict) -> dict:
Example:
>>> client.init_columns("sandbox", {"a": None, "b.c": "eV", "b.d": "mm", "e": ""})
>>> client.init_columns({"a": None, "b.c": "eV", "b.d": "mm", "e": ""})
This example will result in column headers on the project landing page of the form
Expand All @@ -761,15 +771,14 @@ def init_columns(self, name: str, columns: dict) -> dict:
Args:
name (str): name of the project for which to initialize data columns
columns (dict): dictionary mapping data column to its unit
"""
if not isinstance(name, str):
return {"error": "`name` argument must be a string!"}

if not isinstance(columns, dict):
return {"error": "`columns` argument must be a dict!"}

if not self.project:
return {"error": "initialize client with project argument!"}

existing_columns = set()
for k, v in columns.items():
if k in COMPONENTS:
Expand Down Expand Up @@ -797,16 +806,11 @@ def init_columns(self, name: str, columns: dict) -> dict:

existing_columns.add(k)

valid_projects = self.get_project_names()

if name not in valid_projects:
return {"error": f"{name} doesn't exist or you don't have access!"}

# sort to avoid "overlapping columns" error in handsontable's NestedHeaders
sorted_columns = flatten(unflatten(columns, splitter="dot"), reducer="dot")

# reconcile with existing columns
resp = self.projects.get_entry(pk=name, _fields=["columns"]).result()
resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result()
existing_columns, new_columns = {}, []

for col in resp["columns"]:
Expand Down Expand Up @@ -854,32 +858,30 @@ def init_columns(self, name: str, columns: dict) -> dict:
if not valid:
return {"error": valid}

self.projects.update_entry(pk=name, project={"columns": []}).result()
return self.projects.update_entry(pk=name, project=payload).result()
self.projects.update_entry(pk=self.project, project={"columns": []}).result()
return self.projects.update_entry(pk=self.project, project=payload).result()

def delete_contributions(
self,
name: str,
query: dict = None,
timeout: int = -1
):
"""Remove all contributions for a project and query
def delete_contributions(self, query: dict = None, timeout: int = -1):
"""Remove all contributions for a query
Note: This also resets the columns field for a project. It might have to be
re-initialized via `client.init_columns()`.
Args:
name (str): name of the project for which to delete contributions
query (dict): optional query to select contributions
timeout (int): cancel remaining requests if timeout exceeded (in seconds)
"""
if not self.project or not query or "project" not in query:
print("initialize client with project, or include project in query!")
return

tic = time.perf_counter()
query = query or {}
query["project"] = name
cids = list(self.get_all_ids(query).get(name, {}).get("ids", set()))
query["project"] = self.project
cids = list(self.get_all_ids(query).get(self.project, {}).get("ids", set()))

if not cids:
print(f"There aren't any contributions to delete for {name}")
print(f"There aren't any contributions to delete for {self.project}")
return

total = len(cids)
Expand Down Expand Up @@ -922,7 +924,10 @@ def get_totals(
return

query = query or {}
skip_keys = {"per_page", "_fields", "format"}
if self.project and "project" not in query:
query["project"] = self.project

skip_keys = {"per_page", "_fields", "format", "_sort"}
query = {k: v for k, v in query.items() if k not in skip_keys}
query["_fields"] = [] # only need totals -> explicitly request no fields
queries = self._split_query(query, resource=resource, op=op) # don't paginate
Expand All @@ -932,7 +937,7 @@ def get_totals(

for resp in responses.values():
for k in result:
result[k] += resp["result"][k]
result[k] += resp.get("result", {}).get(k, 0)

return result["total_count"], result["total_pages"]

Expand All @@ -945,8 +950,13 @@ def get_unique_identifiers_flags(self, projects: list = None) -> dict:
Returns:
{"<project-name>": True|False, ...}
"""
unique_identifiers = {}
query = {"name__in": projects} if projects else {}
unique_identifiers, query = {}, {}

if projects:
query = {"name__in": projects}
elif self.project:
query = {"name": self.project}

resp = self.projects.get_entries(
_fields=["name", "unique_identifiers"], **query
).result()
Expand Down Expand Up @@ -1036,6 +1046,9 @@ def get_all_ids(

ret = {}
query = query or {}
if self.project and "project" not in query:
query["project"] = self.project

[query.pop(k, None) for k in ["page", "per_page", "_fields"]]
id_fields = {"project", "id", "identifier"}

Expand Down Expand Up @@ -1116,7 +1129,6 @@ def get_all_ids(

def update_contributions(
self,
name: str,
data: dict,
query: dict = None,
timeout: int = -1
Expand All @@ -1126,7 +1138,6 @@ def update_contributions(
See `client.contributions.get_entries()` for keyword arguments used in query.
Args:
name (str): name of the project
data (dict): update to apply on every matching contribution
query (dict): optional query to select contributions
timeout (int): cancel remaining requests if timeout exceeded (in seconds)
Expand All @@ -1140,15 +1151,22 @@ def update_contributions(
return {"error": valid}

query = query or {}
query["project"] = name
cids = list(self.get_all_ids(query).get(name, {}).get("ids", set()))

if not self.project or not query or "project" not in query:
return {"error": "initialize client with project, or include project in query!"}

if "project" in query and self.project != query["project"]:
return {"error": f"client initialized with different project {self.project}!"}

query["project"] = self.project
cids = list(self.get_all_ids(query).get(self.project, {}).get("ids", set()))

if not cids:
print(f"There aren't any contributions to update for {name}")
print(f"There aren't any contributions to update for {self.project}")
return

# get current list of data columns to decide if swagger reload is needed
resp = self.projects.get_entry(pk=name, _fields=["columns"]).result()
resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result()
old_paths = set(c["path"] for c in resp["columns"])

total = len(cids)
Expand All @@ -1163,7 +1181,7 @@ def update_contributions(
updated = sum(resp["count"] for _, resp in responses.items())

if updated:
resp = self.projects.get_entry(pk=name, _fields=["columns"]).result()
resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result()
new_paths = set(c["path"] for c in resp["columns"])

if new_paths != old_paths:
Expand All @@ -1174,44 +1192,39 @@ def update_contributions(

def make_public(
self,
name: str,
query: dict = None,
recursive: bool = False,
timeout: int = -1
) -> dict:
"""Publish a project and optionally its contributions
Args:
name (str): name of the project
query (dict): optional query to select contributions
recursive (bool): also publish according contributions?
"""
return self._set_is_public(
True, name, query=query, recursive=recursive, timeout=timeout
True, query=query, recursive=recursive, timeout=timeout
)

def make_private(
self,
name: str,
query: dict = None,
recursive: bool = False,
timeout: int = -1
) -> dict:
"""Make a project and optionally its contributions private
Args:
name (str): name of the project
query (dict): optional query to select contributions
recursive (bool): also make according contributions private?
"""
return self._set_is_public(
False, name, query=query, recursive=recursive, timeout=timeout
False, query=query, recursive=recursive, timeout=timeout
)

def _set_is_public(
self,
is_public: bool,
name: str,
query: dict = None,
recursive: bool = False,
timeout: int = -1
Expand All @@ -1220,32 +1233,34 @@ def _set_is_public(
Args:
is_public (bool): target value for `is_public` flag
name (str): name of the project
query (dict): optional query to select contributions
recursive (bool): also set `is_public` for according contributions?
timeout (int): cancel remaining requests if timeout exceeded (in seconds)
"""
if not self.project or not query or "project" not in query:
return {"error": "initialize client with project, or include project in query!"}

try:
resp = self.projects.get_entry(pk=name, _fields=["is_public"]).result()
resp = self.projects.get_entry(pk=self.project, _fields=["is_public"]).result()
except HTTPNotFound:
return {"error": f"project `{name}` not found or access denied!"}
return {"error": f"project `{self.project}` not found or access denied!"}

if not recursive and resp["is_public"] == is_public:
return {"warning": f"`is_public` already set to {is_public} for `{name}`."}
return {"warning": f"`is_public` already set to {is_public} for `{self.project}`."}

ret = {}

if resp["is_public"] != is_public:
resp = self.projects.update_entry(
pk=name, project={"is_public": is_public}
pk=self.project, project={"is_public": is_public}
).result()
ret["published"] = resp["is_public"] == is_public

if recursive:
query = query or {}
query["is_public"] = not is_public
ret["contributions"] = self.update_contributions(
name, {"is_public": is_public}, query=query, timeout=timeout
{"is_public": is_public}, query=query, timeout=timeout
)

return ret
Expand Down Expand Up @@ -1313,6 +1328,9 @@ def submit_contributions(
collect_ids.append(c["id"])
elif "project" in c and "identifier" in c:
project_names.add(c["project"])
elif self.project and "project" not in c and "identifier" in c:
project_names.add(self.project)
contributions[idx]["project"] = self.project
else:
return {
"error": f"Provide `project` & `identifier`, or `id` for contribution #{idx}!"
Expand Down

0 comments on commit 17c0aea

Please sign in to comment.