Skip to content

Commit

Permalink
refactor: Remove secret usage from Vault PKI implementation for CA CSR (
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielArndt authored Jul 23, 2024
1 parent e203912 commit aa35974
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 71 deletions.
100 changes: 57 additions & 43 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
KV_SECRET_PREFIX = "kv-creds-"
LOG_FORWARDING_RELATION_NAME = "logging"
PEER_RELATION_NAME = "vault-peers"
PKI_CSR_SECRET_LABEL = "pki-csr"
PKI_MOUNT = "charm-pki"
PKI_RELATION_NAME = "vault-pki"
PKI_ROLE_NAME = "charm"
Expand Down Expand Up @@ -394,8 +393,6 @@ def _configure(self, event: Optional[ConfigChangedEvent] = None) -> None: # noq
return
except VaultClientError:
return
if not all(self._get_approle_auth_secret()):
return
if not (vault := self._get_active_vault_client()):
return
self._configure_pki_secrets_engine()
Expand Down Expand Up @@ -484,26 +481,57 @@ def _configure_pki_secrets_engine(self) -> None:
if not self._common_name_config_is_valid():
logger.debug("Common name config is not valid, skipping")
return
common_name = self._get_config_common_name()
config_common_name = self._get_config_common_name()
vault.enable_secrets_engine(SecretsBackend.PKI, PKI_MOUNT)
if not self._is_intermediate_ca_common_name_valid(vault, common_name):
csr = vault.generate_pki_intermediate_ca_csr(mount=PKI_MOUNT, common_name=common_name)
if not self._common_name_matches_current_csr(config_common_name):
csr = vault.generate_pki_intermediate_ca_csr(
mount=PKI_MOUNT, common_name=config_common_name
)
self._revoke_issued_certificates_and_remove_csr()
self.tls_certificates_pki.request_certificate_creation(
certificate_signing_request=csr.encode(),
is_ca=True,
)
self._set_juju_secret(PKI_CSR_SECRET_LABEL, {"csr": csr})

def _is_intermediate_ca_common_name_valid(self, vault: Vault, common_name: str) -> bool:
def _revoke_issued_certificates_and_remove_csr(self) -> None:
"""Revoke all certificates issued by the PKI secrets engine and remove intermediate CA CSR."""
csrs = self.tls_certificates_pki.get_requirer_csrs()
# There should only be one of them, but for sanity we iterate the list.
for csr in csrs:
self.tls_certificates_pki.request_certificate_revocation(csr.csr.encode())
logger.info("Revoking all certificates issued by the PKI backend.")
self.vault_pki.revoke_all_certificates()

def _is_intermediate_ca_common_name_match(self, vault: Vault, common_name: str) -> bool:
"""Check if the intermediate CA is set with the valid common name."""
intermediate_ca = vault.get_intermediate_ca(mount=PKI_MOUNT)
if not intermediate_ca:
return False
intermediate_ca_common_name = get_common_name_from_certificate(intermediate_ca)
return intermediate_ca_common_name == common_name

def _is_intermediate_ca_set(self, vault: Vault, certificate: str) -> bool:
"""Check if the intermediate CA is set in the PKI secrets engine."""
def _get_pki_intermediate_ca_csr(self) -> Optional[str]:
"""Get the current CSR from the relation data."""
csrs = self.tls_certificates_pki.get_requirer_csrs()
if not csrs:
return None
assert len(csrs) == 1, f"Only one CSR should be available, found {len(csrs)}: {csrs}"
csr = csrs[0].csr
return csr

def _common_name_matches_current_csr(self, common_name: str) -> bool:
"""Return True if the common name provided matches what is in the CSR in the relation data."""
csr = self._get_pki_intermediate_ca_csr()
if not csr:
return False
csr_common_name = get_common_name_from_csr(csr)
is_match = csr_common_name == common_name
if not is_match:
logger.info("Common name changed from `%s` to `%s`", csr_common_name, common_name)
return is_match

def _vault_intermediate_ca_matches(self, vault: Vault, certificate: str) -> bool:
"""Check if the intermediate CA in the Vault PKI enginer is set to the certificate provided."""
intermediate_ca = vault.get_intermediate_ca(mount=PKI_MOUNT)
return certificate == intermediate_ca

Expand All @@ -520,19 +548,23 @@ def _add_ca_certificate_to_pki_secrets_engine(self) -> None:
if not certificate:
logger.debug("No certificate available")
return
common_name = self._get_config_common_name()
if not common_name:
config_common_name = self._get_config_common_name()
if not config_common_name:
logger.error("Common name is not set in the charm config")
return
if not self._is_intermediate_ca_common_name_valid(
vault, common_name
) or not self._is_intermediate_ca_set(vault, certificate):

certificate_common_name = get_common_name_from_certificate(certificate)
if (
config_common_name == certificate_common_name
and not self._vault_intermediate_ca_matches(vault, certificate)
):
vault.set_pki_intermediate_ca_certificate(certificate=certificate, mount=PKI_MOUNT)

if not vault.is_common_name_allowed_in_pki_role(
role=PKI_ROLE_NAME, mount=PKI_MOUNT, common_name=common_name
role=PKI_ROLE_NAME, mount=PKI_MOUNT, common_name=config_common_name
):
vault.create_or_update_pki_charm_role(
allowed_domains=common_name,
allowed_domains=config_common_name,
mount=PKI_MOUNT,
role=PKI_ROLE_NAME,
)
Expand Down Expand Up @@ -613,18 +645,14 @@ def _get_pki_ca_certificate(self) -> Optional[str]:
assigned_certificates = self.tls_certificates_pki.get_assigned_certificates()
if not assigned_certificates:
return None
if not self._pki_csr_secret_set():
logger.info("PKI CSR not set in secrets")
return None
pki_csr = self._get_pki_csr_secret()
if not pki_csr:
logger.warning("PKI CSR not found in secrets")
return None
for assigned_certificate in assigned_certificates:
if assigned_certificate.csr == pki_csr:
return assigned_certificate.certificate
logger.info("No certificate matches the PKI CSR in secrets")
return None
if len(assigned_certificates) > 1:
logger.warning(
f"Only one certificate should be available, found {len(assigned_certificates)}"
)

# Return the last certificate, as it is probably the most recent one.
# We shouldn't have more than 1 anyway.
return assigned_certificates[-1].certificate

def _generate_pki_certificate_for_requirer(self, csr: str, relation_id: int):
"""Generate a PKI certificate for a TLS requirer."""
Expand Down Expand Up @@ -1222,20 +1250,6 @@ def _push_config_file_to_workload(self, content: str):
self._container.push(path=VAULT_CONFIG_FILE_PATH, source=content)
logger.info("Pushed %s config file", VAULT_CONFIG_FILE_PATH)

def _get_pki_csr_secret(self) -> Optional[str]:
"""Return the PKI CSR secret."""
if not self._pki_csr_secret_set():
raise RuntimeError("PKI CSR secret not set.")
return self._get_juju_secret_field(PKI_CSR_SECRET_LABEL, "csr")

def _pki_csr_secret_set(self) -> bool:
"""Return whether PKI CSR secret is stored."""
try:
self.model.get_secret(label=PKI_CSR_SECRET_LABEL)
return True
except SecretNotFoundError:
return False

def _get_approle_auth_secret(self) -> Tuple[Optional[str], Optional[str]]:
"""Get the vault approle login details secret.
Expand Down
42 changes: 14 additions & 28 deletions tests/unit/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
AUTOUNSEAL_MOUNT_PATH,
CHARM_POLICY_NAME,
CHARM_POLICY_PATH,
PKI_CSR_SECRET_LABEL,
PKI_RELATION_NAME,
S3_RELATION_NAME,
TLS_CERTIFICATES_PKI_RELATION_NAME,
Expand Down Expand Up @@ -191,25 +190,6 @@ def _set_ca_certificate_secret(self, private_key: str, certificate: str) -> None
secret.set_info(label=CA_CERTIFICATE_JUJU_SECRET_LABEL)
self.harness.set_leader(original_leader_state)

def _set_csr_secret_in_peer_relation(self, relation_id: int, csr: str) -> None:
"""Set the csr secret in the peer relation."""
content = {
"csr": csr,
}
original_leader_state = self.harness.charm.unit.is_leader()
with self.harness.hooks_disabled():
self.harness.set_leader(is_leader=True)
secret_id = self.harness.add_model_secret(owner=self.app_name, content=content)
secret = self.harness.model.get_secret(id=secret_id)
secret.set_info(label=PKI_CSR_SECRET_LABEL)
self.harness.set_leader(original_leader_state)
key_values = {"vault-pki-csr-secret-id": secret_id}
self.harness.update_relation_data(
app_or_unit=self.app_name,
relation_id=relation_id,
key_values=key_values,
)

def setup_vault_kv_relation(self) -> tuple:
app_name = VAULT_KV_REQUIRER_APPLICATION_NAME
unit_name = app_name + "/0"
Expand Down Expand Up @@ -1613,7 +1593,8 @@ def test_given_vault_pki_configured_when_common_name_is_changed_then_new_certifi
certificate_signing_request=csr.encode(), is_ca=True
)

@patch("charm.get_common_name_from_certificate", new=Mock)
@patch("charm.get_common_name_from_csr", new=Mock)
@patch("charm.get_common_name_from_certificate", new=Mock(return_value="vault"))
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.request_certificate_creation")
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.get_requirer_csrs")
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.get_provider_certificates")
Expand Down Expand Up @@ -1643,13 +1624,11 @@ def test_given_vault_is_available_when_pki_certificate_is_available_then_certifi
self.harness.update_config({"common_name": "vault"})
self.harness.set_leader(is_leader=True)
self.harness.set_can_connect(container=self.container_name, val=True)
peer_relation_id = self._set_peer_relation()
self._set_peer_relation()
self._set_approle_secret(
role_id="root token content",
secret_id="whatever secret id",
)
self._set_csr_secret_in_peer_relation(relation_id=peer_relation_id, csr=csr)

relation_id = self.harness.add_relation(
relation_name=TLS_CERTIFICATES_PKI_RELATION_NAME, remote_app="tls-provider"
)
Expand Down Expand Up @@ -1699,7 +1678,8 @@ def test_given_vault_is_available_when_pki_certificate_is_available_then_certifi
self.mock_vault.make_latest_pki_issuer_default.assert_called_with(mount=PKI_MOUNT)

@patch("ops.model.Container.restart", new=Mock)
@patch("charm.get_common_name_from_certificate", new=Mock)
@patch("charm.get_common_name_from_csr", new=Mock)
@patch("charm.get_common_name_from_certificate")
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.request_certificate_creation")
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.get_requirer_csrs")
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.get_provider_certificates")
Expand All @@ -1710,6 +1690,7 @@ def test_given_vault_pki_configured_when_common_name_is_changed_then_new_certifi
patch_get_provider_certificates,
patch_get_requirer_csrs,
patch_request_certificate_creation,
patch_get_common_name_from_certificate,
):
csr = "some csr content"
self.mock_vault.configure_mock(
Expand All @@ -1729,13 +1710,11 @@ def test_given_vault_pki_configured_when_common_name_is_changed_then_new_certifi
self.harness.update_config({"common_name": "vault"})
self.harness.set_leader(is_leader=True)
self.harness.set_can_connect(container=self.container_name, val=True)
peer_relation_id = self._set_peer_relation()
self._set_peer_relation()
self._set_approle_secret(
role_id="root token content",
secret_id="whatever secret id",
)
self._set_csr_secret_in_peer_relation(relation_id=peer_relation_id, csr=csr)

relation_id = self.harness.add_relation(
relation_name=TLS_CERTIFICATES_PKI_RELATION_NAME, remote_app="tls-provider"
)
Expand All @@ -1753,12 +1732,15 @@ def test_given_vault_pki_configured_when_common_name_is_changed_then_new_certifi
patch_get_provider_certificates.return_value = [provider_certificate]
patch_get_requirer_csrs.return_value = [Mock(csr=csr)]
self.harness.add_storage("config", attach=True)
patch_get_common_name_from_certificate.return_value = "new_common_name"

# Reset mock counts, in case they were called during setup
self.mock_vault.reset_mock()

# When
self.harness.update_config({"common_name": "new_common_name"})

# Then
self.mock_vault.set_pki_intermediate_ca_certificate.assert_called_with(
certificate=certificate,
mount=PKI_MOUNT,
Expand All @@ -1767,6 +1749,10 @@ def test_given_vault_pki_configured_when_common_name_is_changed_then_new_certifi
allowed_domains="new_common_name", mount=PKI_MOUNT, role=PKI_ROLE_NAME
)

@patch(
f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesRequiresV3.request_certificate_creation",
new=Mock(),
)
@patch(f"{TLS_CERTIFICATES_LIB_PATH}.TLSCertificatesProvidesV3.set_relation_certificate")
@patch("charm.get_common_name_from_certificate")
@patch("charm.get_common_name_from_csr")
Expand Down

0 comments on commit aa35974

Please sign in to comment.