diff --git a/checks/tasks/rpki.py b/checks/tasks/rpki.py index 9614facff..1886b74ab 100644 --- a/checks/tasks/rpki.py +++ b/checks/tasks/rpki.py @@ -122,8 +122,8 @@ def callback(results: Mapping[TestName, TestResult], domain, parent, parent_name return parent, results -web_registered = check_registry("web_rpki", web_callback, shared.resolve_a_aaaa) -batch_web_registered = check_registry("batch_web_rpki", batch_web_callback, shared.batch_resolve_a_aaaa) +web_registered = check_registry("web_rpki", web_callback, shared.resolve_all_a_aaaa) +batch_web_registered = check_registry("batch_web_rpki", batch_web_callback, shared.batch_resolve_all_a_aaaa) mail_registered = check_registry("mail_rpki", mail_callback, shared.resolve_mx) batch_mail_registered = check_registry("batch_mail_rpki", batch_mail_callback, shared.batch_resolve_mx) diff --git a/checks/tasks/shared.py b/checks/tasks/shared.py index 9f07d5fb0..5dd553e10 100644 --- a/checks/tasks/shared.py +++ b/checks/tasks/shared.py @@ -64,6 +64,26 @@ def batch_resolve_a_aaaa(self, qname, *args, **kwargs): return do_resolve_a_aaaa(self, qname, *args, **kwargs) +@shared_task( + bind=True, + soft_time_limit=settings.SHARED_TASK_SOFT_TIME_LIMIT_HIGH, + time_limit=settings.SHARED_TASK_TIME_LIMIT_HIGH, + base=SetupUnboundContext, +) +def resolve_all_a_aaaa(self, qname, *args, **kwargs): + return do_resolve_all_a_aaaa(self, qname, *args, **kwargs) + + +@batch_shared_task( + bind=True, + soft_time_limit=settings.BATCH_SHARED_TASK_SOFT_TIME_LIMIT_HIGH, + time_limit=settings.BATCH_SHARED_TASK_TIME_LIMIT_HIGH, + base=SetupUnboundContext, +) +def batch_resolve_all_a_aaaa(self, qname, *args, **kwargs): + return do_resolve_all_a_aaaa(self, qname, *args, **kwargs) + + @shared_task( bind=True, soft_time_limit=settings.SHARED_TASK_SOFT_TIME_LIMIT_HIGH, @@ -162,6 +182,18 @@ def do_resolve_a_aaaa(self, qname, *args, **kwargs): return af_ip_pairs +def do_resolve_all_a_aaaa(self, qname, *args, **kwargs): + """Resolve all A and AAAA records and return all results for each type.""" + af_ip_pairs = [] + ip4 = self.resolve(qname, unbound.RR_TYPE_A) + for ip in ip4: + af_ip_pairs.append((socket.AF_INET, ip)) + ip6 = self.resolve(qname, unbound.RR_TYPE_AAAA) + for ip in ip6: + af_ip_pairs.append((socket.AF_INET6, ip)) + return af_ip_pairs + + def do_resolve_mx_ips(self, url, *args, **kwargs): """Resolve the domain's mailservers returns [(mailserver, af_ip_pairs)] @@ -172,13 +204,7 @@ def do_resolve_mx_ips(self, url, *args, **kwargs): if status is not MxStatus.has_mx: continue - af_ip_pairs = [] - ip4 = self.resolve(qname, unbound.RR_TYPE_A) - for ip in ip4: - af_ip_pairs.append((socket.AF_INET, ip)) - ip6 = self.resolve(qname, unbound.RR_TYPE_AAAA) - for ip in ip6: - af_ip_pairs.append((socket.AF_INET6, ip)) + af_ip_pairs = do_resolve_all_a_aaaa(self, url, *args, **kwargs) mx_ips_pairs.append((qname, af_ip_pairs)) return mx_ips_pairs @@ -195,7 +221,7 @@ def do_resolve_ns_ips(self, url, *args, **kwargs): next_label = next_label[next_label.find(".") + 1 :] for rr in rrset: - yield (rr, do_resolve_a_aaaa(self, rr)) + yield (rr, do_resolve_all_a_aaaa(self, rr)) def resolve_dane(task, port, dname, check_nxdomain=False):