diff --git a/src/zdns/dnssec.go b/src/zdns/dnssec.go index 80c111de..2059134f 100644 --- a/src/zdns/dnssec.go +++ b/src/zdns/dnssec.go @@ -33,15 +33,10 @@ func (r *Resolver) validateChainOfDNSSECTrust(ctx context.Context, msg *dns.Msg, typeToRRSets := make(map[uint16][]dns.RR) typeToRRSigs := make(map[uint16][]*dns.RRSIG) - // Extract all the RRSets from the message - for _, rr := range msg.Answer { - rrType := rr.Header().Rrtype - if rrType == dns.TypeRRSIG { - rrSig := rr.(*dns.RRSIG) - typeToRRSigs[rrSig.TypeCovered] = append(typeToRRSigs[rrSig.TypeCovered], rrSig) - } else { - typeToRRSets[rrType] = append(typeToRRSets[rrType], rr) - } + if msg.Authoritative { + updateTypeMapWithRRs(typeToRRSets, typeToRRSigs, msg.Answer) + } else { + updateTypeMapWithRRs(typeToRRSets, typeToRRSigs, msg.Ns) } // Shortcut checks on RRSIG cardinality @@ -51,18 +46,20 @@ func (r *Resolver) validateChainOfDNSSECTrust(ctx context.Context, msg *dns.Msg, return false, trace, nil } - if len(typeToRRSets) != len(typeToRRSigs) { - return false, trace, errors.New("mismatched number of RRsets and RRSIGs") - } - // Verify if for each RRset there is a corresponding RRSIG + typeToRRSetsWithRRSIGs := make(map[uint16][]dns.RR) for rrType := range typeToRRSets { if _, ok := typeToRRSigs[rrType]; !ok { - return false, trace, fmt.Errorf("found RRset for type %s but no RRSIG", dns.TypeToString[rrType]) + if msg.Authoritative || isDNSSECRecordType(rrType) { + return false, trace, fmt.Errorf("found RRset for type %s but no RRSIG", dns.TypeToString[rrType]) + } + } else { + typeToRRSetsWithRRSIGs[rrType] = typeToRRSets[rrType] } } + typeToRRSets = typeToRRSetsWithRRSIGs - r.verboseLog(depth+1, fmt.Sprintf("DNSSEC: Found %d RRsets and %d RRSIGs", len(typeToRRSets), len(typeToRRSigs))) + r.verboseLog(depth+1, fmt.Sprintf("DNSSEC: Found %d RRsets with RRSIGs", len(typeToRRSets))) passed, trace, err := r.validateRRSIGs(ctx, typeToRRSets, typeToRRSigs, nameServer, isIterative, trace, depth) if err != nil { @@ -72,6 +69,27 @@ func (r *Resolver) validateChainOfDNSSECTrust(ctx context.Context, msg *dns.Msg, return passed, trace, nil } +func isDNSSECRecordType(rrType uint16) bool { + switch rrType { + case dns.TypeDNSKEY, dns.TypeRRSIG, dns.TypeDS, dns.TypeNSEC, dns.TypeNSEC3, dns.TypeNSEC3PARAM: + return true + default: + return false + } +} + +func updateTypeMapWithRRs(typeToRRSets map[uint16][]dns.RR, typeToRRSigs map[uint16][]*dns.RRSIG, rrs []dns.RR) { + for _, rr := range rrs { + rrType := rr.Header().Rrtype + if rrType == dns.TypeRRSIG { + rrSig := rr.(*dns.RRSIG) + typeToRRSigs[rrSig.TypeCovered] = append(typeToRRSigs[rrSig.TypeCovered], rrSig) + } else { + typeToRRSets[rrType] = append(typeToRRSets[rrType], rr) + } + } +} + // parseKSKsFromAnswer extracts only KSKs (Key Signing Keys) from a DNSKEY RRset answer, // populating a map where the KeyTag is the key and the DNSKEY is the value. // This function skips ZSKs and returns an error if any unexpected flags or types are encountered.