Skip to content

Commit

Permalink
Add type hints and mypy linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mprpic committed Apr 21, 2022
1 parent f936228 commit dfb99b7
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 50 deletions.
82 changes: 52 additions & 30 deletions cvelib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from datetime import date, datetime
from functools import wraps
from typing import Any, Callable, DefaultDict, Optional, Union

import click

Expand All @@ -17,28 +18,30 @@
}


def validate_cve(ctx, param, value):
def validate_cve(ctx: click.Context, param: click.Parameter, value: Optional[str]) -> Optional[str]:
if value is None:
return
return None
if not CVE_RE.match(value):
raise click.BadParameter("invalid CVE ID.")
return value


def validate_year(ctx, param, value):
def validate_year(
ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
if value is None:
return
return None
# Hopefully this code won't be around in year 10,000.
if not re.match(r"^[1-9]\d{3}$", value):
raise click.BadParameter("invalid year.")
return value


def human_ts(ts):
def human_ts(ts: str) -> str:
return datetime.strptime(ts, "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%c")


def print_cve(cve):
def print_cve(cve: dict) -> None:
click.secho(cve["cve_id"], bold=True)
click.echo(f"├─ State:\t{cve['state']}")
# CVEs reserved by other CNAs do not include information on who requested them and when.
Expand All @@ -50,7 +53,7 @@ def print_cve(cve):
click.echo(f"└─ Owning CNA:\t{cve['owning_cna']}")


def print_table(lines):
def print_table(lines: list) -> None:
"""Print tabulated data based on the widths of the longest values in each column."""
col_widths = []
for item_index in range(len(lines[0])):
Expand All @@ -65,11 +68,11 @@ def print_table(lines):
click.echo(text)


def print_json_data(data):
def print_json_data(data: Union[dict, list]) -> None:
click.echo(json.dumps(data, indent=4, sort_keys=True))


def print_user(user):
def print_user(user: dict) -> None:
name = get_full_name(user)
if name:
click.echo(f"{name} — ", nl=False)
Expand All @@ -85,31 +88,32 @@ def print_user(user):
click.echo(f"└─ Modified:\t{human_ts(user['time']['modified'])}")


def get_full_name(user_data):
def get_full_name(user_data: dict) -> Optional[str]:
# If no name values are defined on a user, the entire `name` object is not returned in the
# user data response; see https://github.com/CVEProject/cve-services/issues/436.
name = user_data.get("name", {})
if name:
return f"{name.get('first', '')} {name.get('last', '')}".strip() or None
return None


def bool_to_text(value):
def bool_to_text(value: Optional[bool]) -> str:
if value is None:
return "N/A"
return "Yes" if value else "No"


def natural_cve_sort(cve):
def natural_cve_sort(cve: str) -> list[int]:
if not cve:
return []
return [int(x) for x in cve.split("-")[1:]]


def handle_cve_api_error(func):
def handle_cve_api_error(func: Callable) -> Callable:
"""Decorator for catching CVE API exceptions and formatting the error message."""

@wraps(func)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Callable:
try:
return func(*args, **kwargs)
except CveApiError as exc:
Expand All @@ -125,7 +129,15 @@ def wrapped(*args, **kwargs):


class Config:
def __init__(self, username, org, api_key, env, api_url, interactive):
def __init__(
self,
username: str,
org: str,
api_key: str,
env: str,
api_url: Optional[str],
interactive: bool,
) -> None:
self.username = username
self.org = org
self.api_key = api_key
Expand All @@ -134,7 +146,7 @@ def __init__(self, username, org, api_key, env, api_url, interactive):
self.interactive = interactive
self.cve_api = self.init_cve_api()

def init_cve_api(self):
def init_cve_api(self) -> CveApi:
return CveApi(
username=self.username,
org=self.org,
Expand Down Expand Up @@ -189,7 +201,15 @@ def init_cve_api(self):
__version__, "-V", "--version", prog_name="cvelib", message="%(prog)s %(version)s"
)
@click.pass_context
def cli(ctx, username, org, api_key, env, api_url, interactive):
def cli(
ctx: click.Context,
username: str,
org: str,
api_key: str,
env: str,
api_url: Optional[str],
interactive: bool,
) -> None:
"""A CLI interface for the CVE Services API."""
ctx.obj = Config(username, org, api_key, env, api_url, interactive)

Expand All @@ -215,7 +235,7 @@ def cli(ctx, username, org, api_key, env, api_url, interactive):
@click.argument("count", default=1, type=click.IntRange(min=1))
@click.pass_context
@handle_cve_api_error
def reserve(ctx, random, year, count, print_raw):
def reserve(ctx: click.Context, random: bool, year: str, count: int, print_raw: bool) -> None:
"""Reserve one or more CVE IDs. COUNT is the number of CVEs to reserve; defaults to 1.
CVE IDs can be reserved one by one (the lowest IDs are reserved first) or in batches of
Expand Down Expand Up @@ -265,7 +285,7 @@ def reserve(ctx, random, year, count, print_raw):
@click.argument("cve_id", callback=validate_cve)
@click.pass_context
@handle_cve_api_error
def show_cve(ctx, print_raw, cve_id):
def show_cve(ctx: click.Context, print_raw: bool, cve_id: str) -> None:
"""Display a specific CVE ID owned by your CNA."""
cve_api = ctx.obj.cve_api
cve = cve_api.show_cve(cve_id=cve_id)
Expand Down Expand Up @@ -297,7 +317,7 @@ def show_cve(ctx, print_raw, cve_id):
)
@click.pass_context
@handle_cve_api_error
def list_cves(ctx, print_raw, sort_by, **query):
def list_cves(ctx: click.Context, print_raw: bool, sort_by: str, **query: dict) -> None:
"""Filter and list reserved CVE IDs owned by your CNA."""
cve_api = ctx.obj.cve_api
cves = list(cve_api.list_cves(**query))
Expand Down Expand Up @@ -338,7 +358,7 @@ def list_cves(ctx, print_raw, sort_by, **query):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def quota(ctx, print_raw):
def quota(ctx: click.Context, print_raw: bool) -> None:
"""Display the available CVE ID quota for your CNA.
\b
Expand Down Expand Up @@ -370,7 +390,7 @@ def quota(ctx, print_raw):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def show_user(ctx, username, print_raw):
def show_user(ctx: click.Context, username: Optional[str], print_raw: bool) -> None:
"""Show information about a user."""
if ctx.invoked_subcommand is not None:
return
Expand All @@ -396,7 +416,7 @@ def show_user(ctx, username, print_raw):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def reset_key(ctx, username, print_raw):
def reset_key(ctx: click.Context, username: Optional[str], print_raw: bool) -> None:
"""Reset a user's personal access token (API key).
This API key is used to authenticate each request to the CVE API.
Expand Down Expand Up @@ -435,7 +455,7 @@ def reset_key(ctx, username, print_raw):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def update_user(ctx, username, **opts_data):
def update_user(ctx: click.Context, username: Optional[str], **opts_data: dict) -> None:
"""Update a user.
To reset a user's API key, use `cve user reset-key`.
Expand All @@ -454,7 +474,7 @@ def update_user(ctx, username, **opts_data):
opt = "active_roles." + opt.replace("_role", "")
elif opt == "active":
# Convert boolean to string since this data is passed as query params
value = str(value).lower()
value = str(value).lower() # type: ignore
user_updates[opt] = value

if not user_updates:
Expand Down Expand Up @@ -489,15 +509,17 @@ def update_user(ctx, username, **opts_data):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def create_user(ctx, username, name_first, name_last, roles, print_raw):
def create_user(
ctx: click.Context, username: str, name_first: str, name_last: str, roles: list, print_raw: bool
) -> None:
"""Create a user in your organization.
This action is restricted to users with the ADMIN role.
Note: Once a user is created, they cannot be removed, only marked as inactive. Only create
users when you really need them.
"""
user_data = defaultdict(dict)
user_data: DefaultDict = defaultdict(dict)
user_data["username"] = username

if name_first:
Expand Down Expand Up @@ -539,7 +561,7 @@ def create_user(ctx, username, name_first, name_last, roles, print_raw):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def show_org(ctx, print_raw):
def show_org(ctx: click.Context, print_raw: bool) -> None:
"""Show information about your organization."""
if ctx.invoked_subcommand is not None:
return
Expand All @@ -560,7 +582,7 @@ def show_org(ctx, print_raw):
@click.option("--raw", "print_raw", default=False, is_flag=True, help="Print response JSON.")
@click.pass_context
@handle_cve_api_error
def users(ctx, print_raw):
def users(ctx: click.Context, print_raw: bool) -> None:
"""List all users in your organization."""
cve_api = ctx.obj.cve_api
org_users = list(cve_api.list_users())
Expand Down Expand Up @@ -588,7 +610,7 @@ def users(ctx, print_raw):

@cli.command()
@click.pass_context
def ping(ctx):
def ping(ctx: click.Context) -> None:
"""Ping the CVE Services API to see if it is up."""
cve_api = ctx.obj.cve_api
ok, error_msg = cve_api.ping()
Expand Down
52 changes: 32 additions & 20 deletions cvelib/cve_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime
from typing import Iterator, Optional, Tuple
from urllib.parse import urljoin

import requests
Expand All @@ -17,15 +19,19 @@ class CveApi:
}
USER_ROLES = ("ADMIN",)

def __init__(self, username, org, api_key, env="prod", url=None):
def __init__(
self, username: str, org: str, api_key: str, env: str = "prod", url: Optional[str] = None
) -> None:
self.username = username
self.org = org
self.api_key = api_key
self.url = url or self.ENVS.get(env)
if not self.url:
raise ValueError("Missing URL for CVE API")
if not url:
url = self.ENVS.get(env)
if not url:
raise ValueError("Missing URL for CVE API")
self.url = url

def http_request(self, method, path, **kwargs):
def http_request(self, method: str, path: str, **kwargs) -> requests.Response:
url = urljoin(self.url, path)
headers = {
"CVE-API-KEY": self.api_key,
Expand Down Expand Up @@ -53,10 +59,10 @@ def http_request(self, method, path, **kwargs):

return response

def get(self, path, **kwargs):
def get(self, path: str, **kwargs) -> requests.Response:
return self.http_request("get", path, **kwargs)

def get_paged(self, path, page_data_attr, params, **kwargs):
def get_paged(self, path: str, page_data_attr: str, params: dict, **kwargs) -> Iterator[dict]:
"""Get data from a paged endpoint.
CVE Services 1.1.0 added pagination on responses longer than the default page size. For
Expand All @@ -81,13 +87,13 @@ def get_paged(self, path, page_data_attr, params, **kwargs):
else:
break

def post(self, path, **kwargs):
def post(self, path: str, **kwargs) -> requests.Response:
return self.http_request("post", path, **kwargs)

def put(self, path, **kwargs):
def put(self, path: str, **kwargs) -> requests.Response:
return self.http_request("put", path, **kwargs)

def reserve(self, count, random, year):
def reserve(self, count: int, random: bool, year: str) -> Tuple[dict, str]:
"""Reserve a set of CVE IDs.
This method returns a tuple containing the reserved CVE IDs, and the remaining ID quota
Expand All @@ -107,10 +113,16 @@ def reserve(self, count, random, year):
response = self.post("cve-id", params=params)
return response.json(), response.headers["CVE-API-REMAINING-QUOTA"]

def show_cve(self, cve_id):
def show_cve(self, cve_id: str) -> dict:
return self.get(f"cve-id/{cve_id}").json()

def list_cves(self, year=None, state=None, reserved_lt=None, reserved_gt=None):
def list_cves(
self,
year: str = None,
state: str = None,
reserved_lt: datetime = None,
reserved_gt: datetime = None,
) -> Iterator[dict]:
params = {}
if year:
params["cve_id_year"] = year
Expand All @@ -122,28 +134,28 @@ def list_cves(self, year=None, state=None, reserved_lt=None, reserved_gt=None):
params["time_reserved.gt"] = reserved_gt.isoformat()
return self.get_paged("cve-id", page_data_attr="cve_ids", params=params)

def quota(self):
def quota(self) -> dict:
return self.get(f"org/{self.org}/id_quota").json()

def show_user(self, username):
def show_user(self, username: str) -> dict:
return self.get(f"org/{self.org}/user/{username}").json()

def reset_api_key(self, username):
def reset_api_key(self, username: str) -> dict:
return self.put(f"org/{self.org}/user/{username}/reset_secret").json()

def create_user(self, **user_data):
def create_user(self, **user_data: dict) -> dict:
return self.post(f"org/{self.org}/user", json=user_data).json()

def update_user(self, username, **user_data):
def update_user(self, username, **user_data: dict) -> dict:
return self.put(f"org/{self.org}/user/{username}", params=user_data).json()

def list_users(self):
def list_users(self) -> Iterator[dict]:
return self.get_paged(f"org/{self.org}/users", page_data_attr="users", params={})

def show_org(self):
def show_org(self) -> dict:
return self.get(f"org/{self.org}").json()

def ping(self):
def ping(self) -> Tuple[bool, Optional[str]]:
"""Check the CVE API status.
Returns a tuple containing a boolean value of whether the request succeeded and any
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
[tool.black]
line-length = 100

[tool.mypy]
warn_unused_configs = true
warn_unreachable = true
warn_no_return = true
warn_unused_ignores = true

[build-system]
requires = [
"setuptools >= 40.9.0",
Expand Down
Loading

0 comments on commit dfb99b7

Please sign in to comment.