diff --git a/mpcontribs-client/mpcontribs/client/__init__.py b/mpcontribs-client/mpcontribs/client/__init__.py index cfbcb9705..c83c97fc9 100644 --- a/mpcontribs-client/mpcontribs/client/__init__.py +++ b/mpcontribs-client/mpcontribs/client/__init__.py @@ -391,13 +391,20 @@ class Client(SwaggerClient): _shared_state = {} - def __init__(self, apikey: str = None, headers: dict = None, host: str = None): + def __init__( + self, + apikey: str = None, + headers: dict = None, + host: str = None, + project: str = None + ): """Initialize the client - only reloads API spec from server as needed Args: apikey (str): API key (or use MPCONTRIBS_API_KEY env var) - ignored if headers set headers (dict): custom headers for localhost connections host (str): host address to connect to (or use MPCONTRIBS_API_HOST env var) + project (str): use this project for all operations (query, update, create, delete) """ # - Kong forwards consumer headers when api-key used for auth # - forward consumer headers when connecting through localhost @@ -430,7 +437,10 @@ def __init__(self, apikey: str = None, headers: dict = None, host: str = None): self.swagger_spec.http_client.headers != self.headers ) or ( self.swagger_spec.spec_dict["host"] != self.host + ) or ( + "project" not in self.__dict__ or self.project != project ): + self.project = project self._load() def __enter__(self): @@ -465,11 +475,15 @@ def _load(self): # expand regex-based query parameters for `data` columns try: - resp = self.projects.get_entries(_fields=["columns"]).result() + query = {"name": self.project} if self.project else {} + resp = self.projects.get_entries(_fields=["columns"], **query).result() except AttributeError: # skip in tests return + if self.project and not resp["data"]: + raise ValueError(f"{self.project} doesn't exist, or access denied!") + columns = {"string": [], "number": []} for project in resp["data"]: @@ -539,7 +553,7 @@ def _split_query( ) -> List[dict]: """Avoid URI too long errors""" pp_default, pp_max = self._get_per_page_default_max(op=op, resource=resource) - per_page = pp_default if "id__in" in query else pp_max + per_page = pp_default if any(k.endswith("__in") for k in query.keys()) else pp_max query["per_page"] = per_page nr_params_to_split = sum( len(v) > per_page for v in query.values() if isinstance(v, list) @@ -603,17 +617,13 @@ def _get_future( setattr(future, "track_id", track_id) return future - def get_project_names(self) -> List[str]: - """Retrieve list of project names.""" - resp = self.projects.get_entries(_fields=["name"]).result() - return [p["name"] for p in resp["data"]] - - def get_project(self, name: str) -> Type[Dict]: + def get_project(self, name: str = None) -> Type[Dict]: """Retrieve full project entry Args: name (str): name of the project """ + name = self.project or name return Dict(self.projects.get_entry(pk=name, _fields=["_all"]).result()) def get_contribution(self, cid: str) -> Type[Dict]: @@ -729,7 +739,7 @@ def get_attachment(self, aid_or_md5: str) -> Type[Attachment]: return Attachment(self.attachments.get_entry(pk=aid, _fields=["_all"]).result()) - def init_columns(self, name: str, columns: dict) -> dict: + def init_columns(self, columns: dict) -> dict: """initialize columns for a project to set their order and desired units The `columns` field tracks the minima and maxima of each `data` field as @@ -750,7 +760,7 @@ def init_columns(self, name: str, columns: dict) -> dict: Example: - >>> client.init_columns("sandbox", {"a": None, "b.c": "eV", "b.d": "mm", "e": ""}) + >>> client.init_columns({"a": None, "b.c": "eV", "b.d": "mm", "e": ""}) This example will result in column headers on the project landing page of the form @@ -761,15 +771,14 @@ def init_columns(self, name: str, columns: dict) -> dict: Args: - name (str): name of the project for which to initialize data columns columns (dict): dictionary mapping data column to its unit """ - if not isinstance(name, str): - return {"error": "`name` argument must be a string!"} - if not isinstance(columns, dict): return {"error": "`columns` argument must be a dict!"} + if not self.project: + return {"error": "initialize client with project argument!"} + existing_columns = set() for k, v in columns.items(): if k in COMPONENTS: @@ -797,16 +806,11 @@ def init_columns(self, name: str, columns: dict) -> dict: existing_columns.add(k) - valid_projects = self.get_project_names() - - if name not in valid_projects: - return {"error": f"{name} doesn't exist or you don't have access!"} - # sort to avoid "overlapping columns" error in handsontable's NestedHeaders sorted_columns = flatten(unflatten(columns, splitter="dot"), reducer="dot") # reconcile with existing columns - resp = self.projects.get_entry(pk=name, _fields=["columns"]).result() + resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result() existing_columns, new_columns = {}, [] for col in resp["columns"]: @@ -854,32 +858,30 @@ def init_columns(self, name: str, columns: dict) -> dict: if not valid: return {"error": valid} - self.projects.update_entry(pk=name, project={"columns": []}).result() - return self.projects.update_entry(pk=name, project=payload).result() + self.projects.update_entry(pk=self.project, project={"columns": []}).result() + return self.projects.update_entry(pk=self.project, project=payload).result() - def delete_contributions( - self, - name: str, - query: dict = None, - timeout: int = -1 - ): - """Remove all contributions for a project and query + def delete_contributions(self, query: dict = None, timeout: int = -1): + """Remove all contributions for a query Note: This also resets the columns field for a project. It might have to be re-initialized via `client.init_columns()`. Args: - name (str): name of the project for which to delete contributions query (dict): optional query to select contributions timeout (int): cancel remaining requests if timeout exceeded (in seconds) """ + if not self.project or not query or "project" not in query: + print("initialize client with project, or include project in query!") + return + tic = time.perf_counter() query = query or {} - query["project"] = name - cids = list(self.get_all_ids(query).get(name, {}).get("ids", set())) + query["project"] = self.project + cids = list(self.get_all_ids(query).get(self.project, {}).get("ids", set())) if not cids: - print(f"There aren't any contributions to delete for {name}") + print(f"There aren't any contributions to delete for {self.project}") return total = len(cids) @@ -922,7 +924,10 @@ def get_totals( return query = query or {} - skip_keys = {"per_page", "_fields", "format"} + if self.project and "project" not in query: + query["project"] = self.project + + skip_keys = {"per_page", "_fields", "format", "_sort"} query = {k: v for k, v in query.items() if k not in skip_keys} query["_fields"] = [] # only need totals -> explicitly request no fields queries = self._split_query(query, resource=resource, op=op) # don't paginate @@ -932,7 +937,7 @@ def get_totals( for resp in responses.values(): for k in result: - result[k] += resp["result"][k] + result[k] += resp.get("result", {}).get(k, 0) return result["total_count"], result["total_pages"] @@ -945,8 +950,13 @@ def get_unique_identifiers_flags(self, projects: list = None) -> dict: Returns: {"": True|False, ...} """ - unique_identifiers = {} - query = {"name__in": projects} if projects else {} + unique_identifiers, query = {}, {} + + if projects: + query = {"name__in": projects} + elif self.project: + query = {"name": self.project} + resp = self.projects.get_entries( _fields=["name", "unique_identifiers"], **query ).result() @@ -1036,6 +1046,9 @@ def get_all_ids( ret = {} query = query or {} + if self.project and "project" not in query: + query["project"] = self.project + [query.pop(k, None) for k in ["page", "per_page", "_fields"]] id_fields = {"project", "id", "identifier"} @@ -1116,7 +1129,6 @@ def get_all_ids( def update_contributions( self, - name: str, data: dict, query: dict = None, timeout: int = -1 @@ -1126,7 +1138,6 @@ def update_contributions( See `client.contributions.get_entries()` for keyword arguments used in query. Args: - name (str): name of the project data (dict): update to apply on every matching contribution query (dict): optional query to select contributions timeout (int): cancel remaining requests if timeout exceeded (in seconds) @@ -1140,15 +1151,22 @@ def update_contributions( return {"error": valid} query = query or {} - query["project"] = name - cids = list(self.get_all_ids(query).get(name, {}).get("ids", set())) + + if not self.project or not query or "project" not in query: + return {"error": "initialize client with project, or include project in query!"} + + if "project" in query and self.project != query["project"]: + return {"error": f"client initialized with different project {self.project}!"} + + query["project"] = self.project + cids = list(self.get_all_ids(query).get(self.project, {}).get("ids", set())) if not cids: - print(f"There aren't any contributions to update for {name}") + print(f"There aren't any contributions to update for {self.project}") return # get current list of data columns to decide if swagger reload is needed - resp = self.projects.get_entry(pk=name, _fields=["columns"]).result() + resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result() old_paths = set(c["path"] for c in resp["columns"]) total = len(cids) @@ -1163,7 +1181,7 @@ def update_contributions( updated = sum(resp["count"] for _, resp in responses.items()) if updated: - resp = self.projects.get_entry(pk=name, _fields=["columns"]).result() + resp = self.projects.get_entry(pk=self.project, _fields=["columns"]).result() new_paths = set(c["path"] for c in resp["columns"]) if new_paths != old_paths: @@ -1174,7 +1192,6 @@ def update_contributions( def make_public( self, - name: str, query: dict = None, recursive: bool = False, timeout: int = -1 @@ -1182,17 +1199,15 @@ def make_public( """Publish a project and optionally its contributions Args: - name (str): name of the project query (dict): optional query to select contributions recursive (bool): also publish according contributions? """ return self._set_is_public( - True, name, query=query, recursive=recursive, timeout=timeout + True, query=query, recursive=recursive, timeout=timeout ) def make_private( self, - name: str, query: dict = None, recursive: bool = False, timeout: int = -1 @@ -1200,18 +1215,16 @@ def make_private( """Make a project and optionally its contributions private Args: - name (str): name of the project query (dict): optional query to select contributions recursive (bool): also make according contributions private? """ return self._set_is_public( - False, name, query=query, recursive=recursive, timeout=timeout + False, query=query, recursive=recursive, timeout=timeout ) def _set_is_public( self, is_public: bool, - name: str, query: dict = None, recursive: bool = False, timeout: int = -1 @@ -1220,24 +1233,26 @@ def _set_is_public( Args: is_public (bool): target value for `is_public` flag - name (str): name of the project query (dict): optional query to select contributions recursive (bool): also set `is_public` for according contributions? timeout (int): cancel remaining requests if timeout exceeded (in seconds) """ + if not self.project or not query or "project" not in query: + return {"error": "initialize client with project, or include project in query!"} + try: - resp = self.projects.get_entry(pk=name, _fields=["is_public"]).result() + resp = self.projects.get_entry(pk=self.project, _fields=["is_public"]).result() except HTTPNotFound: - return {"error": f"project `{name}` not found or access denied!"} + return {"error": f"project `{self.project}` not found or access denied!"} if not recursive and resp["is_public"] == is_public: - return {"warning": f"`is_public` already set to {is_public} for `{name}`."} + return {"warning": f"`is_public` already set to {is_public} for `{self.project}`."} ret = {} if resp["is_public"] != is_public: resp = self.projects.update_entry( - pk=name, project={"is_public": is_public} + pk=self.project, project={"is_public": is_public} ).result() ret["published"] = resp["is_public"] == is_public @@ -1245,7 +1260,7 @@ def _set_is_public( query = query or {} query["is_public"] = not is_public ret["contributions"] = self.update_contributions( - name, {"is_public": is_public}, query=query, timeout=timeout + {"is_public": is_public}, query=query, timeout=timeout ) return ret @@ -1313,6 +1328,9 @@ def submit_contributions( collect_ids.append(c["id"]) elif "project" in c and "identifier" in c: project_names.add(c["project"]) + elif self.project and "project" not in c and "identifier" in c: + project_names.add(self.project) + contributions[idx]["project"] = self.project else: return { "error": f"Provide `project` & `identifier`, or `id` for contribution #{idx}!"