Skip to content

Commit

Permalink
style: adapt ruff formatting #138
Browse files Browse the repository at this point in the history
  • Loading branch information
knrdl committed Nov 13, 2024
1 parent 33e418d commit 0f6f88e
Show file tree
Hide file tree
Showing 25 changed files with 427 additions and 263 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
python -m flake8 --max-line-length 179 --ignore=F722,B008,I001,I004,I005 .
pylint --max-line-length=179 --recursive=yes --disable=too-many-branches,no-else-return,broad-exception-caught,missing-module-docstring,missing-class-docstring,missing-function-docstring .
ruff check .
ruff format --check .
5 changes: 2 additions & 3 deletions app/acme/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
class ACMEResponse(JSONResponse):
def render(self, content: dict[str, Any] | None) -> bytes:
return super().render( # remove null fields from responses
{k: v for k, v in content.items() if v is not None}
if content is not None else None
{k: v for k, v in content.items() if v is not None} if content is not None else None
)


Expand All @@ -36,5 +35,5 @@ def render(self, content: dict[str, Any] | None) -> bytes:
async def start_cronjobs():
await asyncio.gather(
certificate_cronjob.start(),
nonce_cronjob.start()
nonce_cronjob.start(),
)
48 changes: 26 additions & 22 deletions app/acme/account/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
tosAgreedType = Literal[True] if settings.acme.terms_of_service_url else (bool | None)
contactType = conlist(
constr(strip_whitespace=True, to_lower=True, pattern=f'^mailto:{settings.acme.mail_target_regex.pattern}$'),
min_length=1, max_length=1
min_length=1,
max_length=1,
)


Expand Down Expand Up @@ -51,17 +52,15 @@ def mail_addr(self) -> str | None:
@api.post('/new-account')
async def create_or_view_account(
response: Response,
data: Annotated[RequestData[NewOrViewAccountPayload], Depends(SignedRequest(NewOrViewAccountPayload, allow_new_account=True))]
data: Annotated[RequestData[NewOrViewAccountPayload], Depends(SignedRequest(NewOrViewAccountPayload, allow_new_account=True))],
):
"""
https://www.rfc-editor.org/rfc/rfc8555.html#section-7.3
"""
jwk_json: dict = data.key.export(as_dict=True)

async with db.transaction() as sql:
result = await sql.record(
'select id, mail, status from accounts where jwk=$1 and (id=$2 or $2::text is null)',
jwk_json, data.account_id)
result = await sql.record("""select id, mail, status from accounts where jwk=$1 and (id=$2 or $2::text is null)""", jwk_json, data.account_id)
account_exists = bool(result)

if account_exists:
Expand All @@ -75,10 +74,12 @@ async def create_or_view_account(
mail_addr = payload.mail_addr
account_id = secrets.token_urlsafe(16)
async with db.transaction() as sql:
account_status = await sql.value("""
insert into accounts (id, mail, jwk) values ($1, $2, $3)
returning status
""", account_id, mail_addr, jwk_json)
account_status = await sql.value(
"""insert into accounts (id, mail, jwk) values ($1, $2, $3) returning status""",
account_id,
mail_addr,
jwk_json,
)
try:
await mail.send_new_account_info_mail(mail_addr)
except Exception:
Expand All @@ -89,7 +90,7 @@ async def create_or_view_account(
return {
'status': account_status,
'contact': ['mailto:' + mail_addr],
'orders': f'{settings.external_url}acme/accounts/{account_id}/orders'
'orders': f'{settings.external_url}acme/accounts/{account_id}/orders',
}


Expand All @@ -100,35 +101,38 @@ async def change_key(data: Annotated[RequestData, Depends(SignedRequest())]):

@api.post('/accounts/{acc_id}')
async def view_or_update_account(
acc_id: str,
data: Annotated[RequestData[UpdateAccountPayload],
Depends(SignedRequest(UpdateAccountPayload, allow_blocked_account=True))]
acc_id: str,
data: Annotated[RequestData[UpdateAccountPayload], Depends(SignedRequest(UpdateAccountPayload, allow_blocked_account=True))],
):
if acc_id != data.account_id:
raise ACMEException(status_code=status.HTTP_403_FORBIDDEN, exctype='unauthorized', detail='wrong kid', new_nonce=data.new_nonce)

if data.payload.contact:
async with db.transaction() as sql:
await sql.exec("update accounts set mail=$1 where id = $2 and status = 'valid'", data.payload.mail_addr, acc_id)
await sql.exec("""update accounts set mail=$1 where id = $2 and status = 'valid'""", data.payload.mail_addr, acc_id)
try:
await mail.send_new_account_info_mail(data.payload.mail_addr)
except Exception:
logger.error('could not send new account mail to "%s"', data.payload.mail_addr, exc_info=True)

if data.payload.status == 'deactivated': # https://www.rfc-editor.org/rfc/rfc8555#section-7.3.6
async with db.transaction() as sql:
await sql.exec("update accounts set status='deactivated' where id = $1", acc_id)
await sql.exec("""
update orders set status='invalid', error=row('unauthorized','account deactived') where account_id = $1 and status <> 'invalid'
""", acc_id)
await sql.exec("""update accounts set status='deactivated' where id = $1""", acc_id)
await sql.exec(
"""
update orders set status='invalid', error=row('unauthorized','account deactived')
where account_id = $1 and status <> 'invalid'
""",
acc_id,
)

async with db.transaction(readonly=True) as sql:
account_status, mail_addr = await sql.record('select status, mail from accounts where id = $1', acc_id)
account_status, mail_addr = await sql.record("""select status, mail from accounts where id = $1""", acc_id)

return {
'status': account_status,
'contact': ['mailto:' + mail_addr],
'orders': f'{settings.external_url}acme/accounts/{acc_id}/orders'
'orders': f'{settings.external_url}acme/accounts/{acc_id}/orders',
}


Expand All @@ -137,7 +141,7 @@ async def view_orders(acc_id: str, data: Annotated[RequestData, Depends(SignedRe
if acc_id != data.account_id:
raise ACMEException(status_code=status.HTTP_403_FORBIDDEN, exctype='unauthorized', detail='wrong account id provided', new_nonce=data.new_nonce)
async with db.transaction(readonly=True) as sql:
orders = [order_id async for order_id, *_ in sql("select id from orders where account_id = $1 and status <> 'invalid'", acc_id)]
orders = [order_id async for order_id, *_ in sql("""select id from orders where account_id = $1 and status <> 'invalid'""", acc_id)]
return {
'orders': [f'{settings.external_url}acme/orders/{order_id}' for order_id in orders]
'orders': [f'{settings.external_url}acme/orders/{order_id}' for order_id in orders],
}
22 changes: 14 additions & 8 deletions app/acme/authorization/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,39 @@ class UpdateAuthzPayload(BaseModel):
@api.post('/authorizations/{authz_id}')
async def view_or_update_authorization(
authz_id: str,
data: Annotated[RequestData[Optional[UpdateAuthzPayload]],
Depends(SignedRequest(Optional[UpdateAuthzPayload]))]
data: Annotated[RequestData[Optional[UpdateAuthzPayload]], Depends(SignedRequest(Optional[UpdateAuthzPayload]))],
):
async with db.transaction(readonly=True) as sql:
record = await sql.record("""
record = await sql.record(
"""
select authz.status, ord.status, ord.expires_at, authz.domain, chal.id, chal.token, chal.status, chal.validated_at
from authorizations authz
join challenges chal on chal.authz_id = authz.id
join orders ord on authz.order_id = ord.id
where authz.id = $1 and ord.account_id = $2
""", authz_id, data.account_id)
""",
authz_id,
data.account_id,
)
if record:
authz_status, order_status, expires_at, domain, chal_id, chal_token, chal_status, chal_validated_at = record
if data.payload and data.payload.status == 'deactivated': # deactivate authz
if authz_status in ['pending', 'valid'] and order_status in ['pending', 'ready']:
async with db.transaction() as sql:
await sql.exec("""
await sql.exec(
"""
update orders set status='invalid', error=row('unauthorized','authorization deactivated')
where id = (select order_id from authorizations where id = $1)
""", authz_id)
authz_status = await sql.value("update authorizations set status = 'deactivated' where id = $1 returning status", authz_id)
""",
authz_id,
)
authz_status = await sql.value("""update authorizations set status = 'deactivated' where id = $1 returning status""", authz_id)
chal = {
'type': 'http-01',
'url': f'{settings.external_url}acme/challenges/{chal_id}',
'token': chal_token,
'status': chal_status,
'validated': chal_validated_at
'validated': chal_validated_at,
}

return {
Expand Down
63 changes: 35 additions & 28 deletions app/acme/certificate/cronjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,39 @@ async def run():
while True:
try:
async with db.transaction(readonly=True) as sql:
results = [record async for record in sql("""
with
expiring_domains as (
select authz.domain, acc.mail, cert.serial_number, cert.not_valid_after from certificates cert
join orders ord on cert.order_id = ord.id
join accounts acc on ord.account_id = acc.id
join authorizations authz on authz.order_id = ord.id
where acc.status = 'valid' and ord.status = 'valid' and cert.revoked_at is null and (
($1::interval is not null and cert.not_valid_after > now() and cert.not_valid_after < now() + $1 and not cert.user_informed_cert_will_expire)
or
(cert.not_valid_after < now() and not cert.user_informed_cert_has_expired)
results = [
record
async for record in sql(
"""
with
expiring_domains as (
select authz.domain, acc.mail, cert.serial_number, cert.not_valid_after from certificates cert
join orders ord on cert.order_id = ord.id
join accounts acc on ord.account_id = acc.id
join authorizations authz on authz.order_id = ord.id
where acc.status = 'valid' and ord.status = 'valid' and cert.revoked_at is null and (
($1::interval is not null and cert.not_valid_after > now() and cert.not_valid_after < now()+$1 and not cert.user_informed_cert_will_expire)
or
(cert.not_valid_after < now() and not cert.user_informed_cert_has_expired)
)
order by authz.domain
),
newest_domains as (
select authz.domain, max(cert.not_valid_after) as not_valid_after from orders ord
join authorizations authz on authz.order_id = ord.id
join certificates cert on cert.order_id = ord.id
join expiring_domains exp on exp.domain = authz.domain
group by authz.domain
)
order by authz.domain
),
newest_domains as (
select authz.domain, max(cert.not_valid_after) as not_valid_after from orders ord
join authorizations authz on authz.order_id = ord.id
join certificates cert on cert.order_id = ord.id
join expiring_domains exp on exp.domain = authz.domain
group by authz.domain
)
select expd.mail, expd.serial_number, expd.not_valid_after, expd.not_valid_after < now() as is_expired, array_agg(expd.domain) as domains
from expiring_domains expd
join newest_domains newd on expd.domain = newd.domain and expd.not_valid_after = newd.not_valid_after
group by expd.mail, expd.serial_number, expd.not_valid_after
having array_length(array_agg(expd.domain), 1) > 0
""", settings.mail.warn_before_cert_expires)]
select expd.mail, expd.serial_number, expd.not_valid_after, expd.not_valid_after < now() as is_expired, array_agg(expd.domain) as domains
from expiring_domains expd
join newest_domains newd on expd.domain = newd.domain and expd.not_valid_after = newd.not_valid_after
group by expd.mail, expd.serial_number, expd.not_valid_after
having array_length(array_agg(expd.domain), 1) > 0
""",
settings.mail.warn_before_cert_expires,
)
]
for mail_addr, serial_number, expires_at, is_expired, domains in results:
if not is_expired and settings.mail.warn_before_cert_expires:
try:
Expand All @@ -48,7 +54,7 @@ async def run():
ok = False
if ok:
async with db.transaction() as sql:
await sql.exec('update certificates set user_informed_cert_will_expire=true where serial_number=$1', serial_number)
await sql.exec("""update certificates set user_informed_cert_will_expire=true where serial_number=$1""", serial_number)
if is_expired and settings.mail.notify_when_cert_expired:
try:
await mail.send_certs_expired_info_mail(receiver=mail_addr, domains=domains, expires_at=expires_at, serial_number=serial_number)
Expand All @@ -58,10 +64,11 @@ async def run():
ok = False
if ok:
async with db.transaction() as sql:
await sql.exec('update certificates set user_informed_cert_has_expired=true where serial_number=$1', serial_number)
await sql.exec("""update certificates set user_informed_cert_has_expired=true where serial_number=$1""", serial_number)
except Exception:
logger.error('could not inform about expiring certificates', exc_info=True)
finally:
await asyncio.sleep(1 * 60 * 60)

if settings.mail.notify_when_cert_expired or settings.mail.warn_before_cert_expires:
asyncio.create_task(run())
41 changes: 26 additions & 15 deletions app/acme/certificate/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,25 @@ class RevokeCertPayload(BaseModel):
api = APIRouter(tags=['acme:certificate'])


@api.post('/certificates/{serial_number}', response_class=Response, responses={
200: {'content': {'application/pem-certificate-chain': {}}}
})
@api.post('/certificates/{serial_number}', response_class=Response, responses={200: {'content': {'application/pem-certificate-chain': {}}}})
async def download_cert(
response: Response, serial_number: constr(pattern='^[0-9A-F]+$'),
response: Response,
serial_number: constr(pattern='^[0-9A-F]+$'),
data: Annotated[RequestData, Depends(SignedRequest())],
accept: str = Header(default='*/*', pattern=r'(application/pem\-certificate\-chain|\*/\*)',
description='Certificates are only supported as "application/pem-certificate-chain"')
accept: str = Header(
default='*/*', pattern=r'(application/pem\-certificate\-chain|\*/\*)', description='Certificates are only supported as "application/pem-certificate-chain"'
),
):
async with db.transaction(readonly=True) as sql:
pem_chain = await sql.value("""
pem_chain = await sql.value(
"""
select cert.chain_pem from certificates cert
join orders ord on cert.order_id = ord.id
where cert.serial_number = $1 and ord.account_id = $2
""", serial_number, data.account_id)
""",
serial_number,
data.account_id,
)
if not pem_chain:
raise ACMEException(status_code=status.HTTP_404_NOT_FOUND, exctype='malformed', detail='specified certificate not found for current account', new_nonce=data.new_nonce)
return Response(content=pem_chain, headers=response.headers, media_type='application/pem-certificate-chain')
Expand All @@ -51,23 +55,30 @@ async def revoke_cert(data: Annotated[RequestData[RevokeCertPayload], Depends(Si
cert = await parse_cert(cert_bytes)
serial_number = SerialNumberConverter.int2hex(cert.serial_number)
async with db.transaction(readonly=True) as sql:
ok = await sql.value("""
ok = await sql.value(
"""
select true from certificates c
join orders o on o.id = c.order_id
join accounts a on a.id = o.account_id
where
c.serial_number = $1 and c.revoked_at is null and
($2::text is null or (a.id = $2::text and a.status='valid')) and a.jwk=$3
""", serial_number, data.account_id, jwk_json)
""",
serial_number,
data.account_id,
jwk_json,
)
if not ok:
raise ACMEException(status_code=status.HTTP_400_BAD_REQUEST, exctype='alreadyRevoked', detail='cert already revoked or not accessible', new_nonce=data.new_nonce)
async with db.transaction(readonly=True) as sql:
revocations = [(sn, rev_at) async for sn, rev_at in sql('select serial_number, revoked_at from certificates where revoked_at is not null')]
revoked_at = await sql.value('select now()')
revocations = [(sn, rev_at) async for sn, rev_at in sql("""select serial_number, revoked_at from certificates where revoked_at is not null""")]
revoked_at = await sql.value("""select now()""")
revocations = set(revocations)
revocations.add((serial_number, revoked_at))
await ca_service.revoke_cert(serial_number=serial_number, revocations=revocations)
async with db.transaction() as sql:
await sql.exec("""
update certificates set revoked_at = $2 where serial_number = $1 and revoked_at is null
""", serial_number, revoked_at)
await sql.exec(
"""update certificates set revoked_at = $2 where serial_number = $1 and revoked_at is null""",
serial_number,
revoked_at,
)
4 changes: 1 addition & 3 deletions app/acme/certificate/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ async def check_csr(csr_der: bytes, ordered_domains: list[str], new_nonce: str |
if not csr.is_signature_valid:
raise ACMEException(status_code=status.HTTP_400_BAD_REQUEST, exctype='badCSR', detail='invalid signature', new_nonce=new_nonce)

sans = csr.extensions.get_extension_for_oid(
x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME
).value.get_values_for_type(x509.DNSName)
sans = csr.extensions.get_extension_for_oid(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value.get_values_for_type(x509.DNSName)
csr_domains = set(sans)
subject_candidates = csr.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
if subject_candidates:
Expand Down
Loading

0 comments on commit 0f6f88e

Please sign in to comment.