diff --git a/requirements.txt b/requirements.txt index 901eb69..aed8903 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ httpx[http2]==0.27.2 pydantic==1.10.13 +validators==0.34.0 diff --git a/talosintelligence_connector.py b/talosintelligence_connector.py index 92b395b..153700a 100644 --- a/talosintelligence_connector.py +++ b/talosintelligence_connector.py @@ -24,6 +24,10 @@ import ipaddress import time import random +import validators +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from datetime import datetime class RetVal(tuple): @@ -160,11 +164,13 @@ def _make_rest_call(self, retry, endpoint, action_result, method="get", **kwargs url, **kwargs ) + self.debug_print(f"got this return value {r}") except Exception as e: self.debug_print(f"Retrying to establish connection to the server for the {i + 1} time") - jittered_delay = random.randint(delay * 0.9, delay * 1.1) + self.debug_print(e) + jittered_delay = random.uniform(delay * 0.9, delay * 1.1) time.sleep(jittered_delay) - delay = max(delay * 2, 256) + delay = min(delay * 2, 256) with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix="test") as temp_file: cert = f"-----BEGIN CERTIFICATE-----\n{self._cert}\n-----END CERTIFICATE-----\n-----BEGIN RSA PRIVATE KEY-----\n{self._key}\n-----END RSA PRIVATE KEY-----\n" @@ -245,6 +251,8 @@ def _handle_domain_reputation(self, param): action_result = self.add_action_result(ActionResult(dict(param))) domain = param['domain'] + if not validators.domain(domain): + return action_result.set_status(phantom.APP_ERROR, "Please provide a valid url") ips = param.get("ips", "") ips_list = [item.strip() for item in ips.split(',') if item.strip()] url_entry = {"raw_url": domain} @@ -280,6 +288,9 @@ def _handle_url_reputation(self, param): action_result = self.add_action_result(ActionResult(dict(param))) url = param['url'] + if not validators.url(url): + return action_result.set_status(phantom.APP_ERROR, "Please provide a valid url") + ips = param.get("ips", "") ips_list = [item.strip() for item in ips.split(',') if item.strip()] url_entry = {"raw_url": url} @@ -425,6 +436,45 @@ def handle_action(self, param): return ret_val + def check_certificate_expiry(self, cert): + not_before = cert.not_valid_before + not_after = cert.not_valid_after + now = datetime.utcnow() + return not_before <= now <= not_after + + def fetch_crls(self, cert): + try: + crl_distribution_points = cert.extensions.get_extension_for_oid( + x509.ExtensionOID.CRL_DISTRIBUTION_POINTS + ).value + + crl_urls = [] + + for point in crl_distribution_points: + for general_name in point.full_name: + if isinstance(general_name, x509.DNSName): + crl_urls.append(f"http://{general_name.value}") + elif isinstance(general_name, x509.UniformResourceIdentifier): + crl_urls.append(general_name.value) + + return crl_urls + except x509.ExtensionNotFound: + self.debug_print("CRL Distribution Points extension not found in the certificate.") + return [] + + def cert_revoked(self, cert, crl_url): + response = requests.get(crl_url) + response.raise_for_status() + + crl = x509.load_der_x509_crl(response.content, default_backend()) + revoked_certificates = crl.revoked_certificates or [] + self.debug_print(f"crl url is {crl} and revoked certs are {revoked_certificates}") + for revoked_cert in revoked_certificates: + if revoked_cert.serial_number == cert.serial_number: + return True + + return False + def initialize(self): # Load the state in initialize, use it to store data # that needs to be accessed across actions @@ -440,6 +490,21 @@ def insert_newlines(string, every=64): self._cert = insert_newlines(config["certificate"]) self._key = insert_newlines(config["key"]) + cert_string = f"-----BEGIN CERTIFICATE-----\n{self._cert}\n-----END CERTIFICATE-----" + cert_pem_data = cert_string.encode("utf-8") + cert = x509.load_pem_x509_certificate(cert_pem_data, default_backend()) + crl_urls = self.fetch_crls(cert) + self.debug_print(f"crl urls are {crl_urls}") + for crl in crl_urls: + if self.cert_revoked(cert, crl): + self.debug_print("Certificate has been revoked. Please get a new one") + return phantom.APP_ERROR + + is_valid = self.check_certificate_expiry(cert) + if not is_valid: + self.debug_print("Certificate is expired. Please use a valid cert") + return phantom.APP_ERROR + self._appinfo = { "product_family": "splunk", "product_id": "soar", @@ -451,7 +516,7 @@ def insert_newlines(string, every=64): self._appinfo["perf_testing"] = True with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix="test") as temp_file: - cert = f"-----BEGIN CERTIFICATE-----\n{self._cert}\n-----END CERTIFICATE-----\n-----BEGIN RSA PRIVATE KEY-----\n{self._key}\n-----END RSA PRIVATE KEY-----\n" + cert = f"{cert_string}\n-----BEGIN RSA PRIVATE KEY-----\n{self._key}\n-----END RSA PRIVATE KEY-----\n" temp_file.write(cert) temp_file.seek(0) # Move the file pointer to the beginning for reading