diff --git a/dns/server.go b/dns/server.go index 7e9ebff..12e297a 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1,8 +1,9 @@ package dns import ( + "net" "strconv" - "sync" + "strings" "sync/atomic" "time" @@ -13,8 +14,9 @@ import ( ) type record struct { - rr dns.RR - name *regexp2.Regexp + rtype uint16 + rvalue string + name *regexp2.Regexp } type filter struct { name *regexp2.Regexp @@ -22,28 +24,42 @@ type filter struct { } type server struct { - records []*record - records_lock sync.RWMutex - filters []*filter - upstreamDNS string + records []*record + filters []*filter + domain string count uint64 } +func joinNames(questions []dns.Question) string { + var names []string + for _, q := range questions { + names = append(names, q.Name) + } + return strings.Join(names, " ") +} + +func joinTypes(questions []dns.Question) string { + var types []string + for _, q := range questions { + types = append(types, dns.TypeToString[q.Qtype]) + } + return strings.Join(types, " ") +} + func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) + m := new(dns.Msg).SetReply(req) + id := atomic.AddUint64(&s.count, 1) startTime := time.Now() defer func() { - log.Println("d"+strconv.FormatUint(id, 10), w.RemoteAddr().String(), time.Since(startTime).Round(1*time.Microsecond), m.Rcode, m.Question[0].Name, m.Answer) + log.Println("d"+strconv.FormatUint(id, 10), w.RemoteAddr().String(), time.Since(startTime).Round(1*time.Microsecond), RcodeTypeMap[m.Rcode], joinTypes(req.Question), joinNames(req.Question)) }() for _, q := range req.Question { for _, r := range s.filters { if ok, _ := r.name.MatchString(q.Name); ok { if r.allowance { - m.Rcode = dns.RcodeSuccess break } else { m.Rcode = dns.RcodeRefused @@ -52,11 +68,32 @@ func (s *server) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { } } } - { - c := new(dns.Client) - in, _, _ := c.Exchange(req, s.upstreamDNS) - w.WriteMsg(in) - return + + for _, q := range req.Question { + for _, r := range s.records { + if q.Qtype == r.rtype { + if ok, _ := r.name.MatchString(q.Name); ok { + var ret dns.RR + switch r.rtype { + case dns.TypeA: + ret = &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(r.rvalue)} + case dns.TypePTR: + ret = &dns.PTR{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 60}, + Ptr: r.rvalue} + default: + m.Rcode = dns.RcodeNotImplemented + goto _end + } + m.Answer = append(m.Answer, ret) + } + } + } + } + if len(m.Answer) == 0 { + m.Rcode = dns.RcodeNameError } _end: w.WriteMsg(m) @@ -72,22 +109,35 @@ func (s *server) AddFilter(name *regexp2.Regexp, allowance bool) error { s.filters = append(s.filters, &filter{name: name, allowance: allowance}) return nil } +func (s *server) AddRecord(name *regexp2.Regexp, rtype uint16, rvalue string) error { + s.records = append(s.records, &record{name: name, rtype: rtype, rvalue: rvalue}) + return nil +} + +func (s *server) AddRecordWithIP(name string, ip string) error { + real_subdomain := name + "." + s.domain + "." + real_ptr := reverseIP(ip) + ".in-addr.arpa." + s.domain + "." + + s.AddRecord(regexp2.MustCompile(Dnsname2Regexp(real_subdomain), 0), dns.TypeA, ip) + s.AddRecord(regexp2.MustCompile(Dnsname2Regexp(real_ptr), 0), dns.TypePTR, real_subdomain) + + return nil +} -// func (s *server) AddRecord(domain string, rr dns.RR) error { -// r, err := regexp2.Compile(domain, 0) -// if err != nil { -// return err -// } -// s.records_lock.Lock() -// s.records = append(s.records, &record{rr: rr, domain: r}) -// s.records_lock.Unlock() -// return nil -// } - -func NewServer(upstreamDNS string) *server { +func NewServer(domain string) *server { return &server{ - records: []*record{}, - filters: []*filter{}, - upstreamDNS: upstreamDNS, + records: []*record{}, + filters: []*filter{}, + domain: domain, + } +} + +func reverseIP(ipAddr string) string { + segments := strings.Split(ipAddr, ".") + + for i, j := 0, len(segments)-1; i < j; i, j = i+1, j-1 { + segments[i], segments[j] = segments[j], segments[i] } + + return strings.Join(segments, ".") } diff --git a/dns/utils.go b/dns/utils.go index 26344a6..85d94ee 100644 --- a/dns/utils.go +++ b/dns/utils.go @@ -1,6 +1,10 @@ package dns -import "strings" +import ( + "strings" + + "github.com/miekg/dns" +) func Dnsnames2Regexps(dnsnames []string) []string { var out []string @@ -16,3 +20,36 @@ func Dnsname2Regexp(dnsname string) (v string) { v = strings.ReplaceAll(v, "*", ".*") return "^" + v + "$" } + +func DnsStringTypeToInt(s string) uint16 { + switch s { + case "A": + return dns.TypeA + case "AAAA": + return dns.TypeAAAA + case "CNAME": + return dns.TypeCNAME + case "MX": + return dns.TypeMX + case "NS": + return dns.TypeNS + case "PTR": + return dns.TypePTR + case "SOA": + return dns.TypeSOA + case "SRV": + return dns.TypeSRV + case "TXT": + return dns.TypeTXT + default: + return 0 + } +} + +var DnsTypeMap = []string{ + "None", "A", "NS", "MD", "MF", "CNAME", "SOA", "MB", "MG", "MR", "NULL", "PTR", "HINFO", "MINFO", "MX", "TXT", "RP", "AFSDB", "X25", "ISDN", "RT", "NSAPPTR", "SIG", "KEY", "PX", "GPOS", "AAAA", "LOC", "NXT", "EID", "NIMLOC", "SRV", "ATMA", "NAPTR", "KX", "CERT", "DNAME", "OPT", "APL", "DS", "SSHFP", "IPSECKEY", "RRSIG", "NSEC", "DNSKEY", "DHCID", "NSEC3", "NSEC3PARAM", "TLSA", "SMIMEA", "HIP", "NINFO", "RKEY", "TALINK", "CDS", "CDNSKEY", "OPENPGPKEY", "CSYNC", "ZONEMD", "SVCB", "HTTPS", "SPF", "UINFO", "UID", "GID", "UNSPEC", "NID", "L32", "L64", "LP", "EUI48", "EUI64", "TKEY", "TSIG", "IXFR", "AXFR", "MAILB", "MAILA", "ANY", "URI", "CAA", "AVC", "DOA", "AMTRELAY", "TA", "DLV", "RESERVED", +} + +var RcodeTypeMap = []string{ + "NoError", "FormErr", "ServFail", "NXDomain", "NotImp", "Refused", "YXDomain", "YXRRSet", "NXRRSet", "NotAuth", "NotZone", "RESERVED11", "RESERVED12", "RESERVED13", "RESERVED14", "RESERVED15", "BADVERS", "BADSIG", "BADKEY", "BADTIME", "BADMODE", "BADNAME", "BADALG", "BADTRUNC", "BADCOOKIE", "RESERVED25", "RESERVED26", "RESERVED27", "RESERVED28", "RESERVED29", "RESERVED30", "RESERVED31", +} diff --git a/ui/config.go b/ui/config.go index 9feb0d2..8a5b8a4 100644 --- a/ui/config.go +++ b/ui/config.go @@ -90,10 +90,11 @@ type UdpLoggerConfig struct { } type DnsConfig struct { - Bind string `yaml:"Bind"` - Upstream string `yaml:"Upstream"` - Records []DnsRecord `yaml:"Records,flow"` - Filters []DnsFilterRule `yaml:"Filters,flow"` + Bind string `yaml:"Bind"` + Domain string `yaml:"Domain"` + Records []DnsRecord `yaml:"Records,flow"` + Filters []DnsFilterRule `yaml:"Filters,flow"` + Binds []DnsBind `yaml:"Binds,flow"` } type DnsRecord struct { @@ -106,3 +107,7 @@ type DnsFilterRule struct { Name string `yaml:"Name"` Allowance bool `yaml:"Allowance"` } +type DnsBind struct { + Name string `yaml:"Name"` + Addr string `yaml:"Addr"` +} diff --git a/ui/myservice.go b/ui/myservice.go index c528603..25b3253 100644 --- a/ui/myservice.go +++ b/ui/myservice.go @@ -242,7 +242,8 @@ func LoadCfg(cfgs []byte) error { log.Println("sys", "tcp", err) os.Exit(-1) } - var Dns = dns.NewServer(cfg.DNS.Upstream) + var Dns = dns.NewServer(cfg.DNS.Domain) + log.Println("sys", "dns", "domain is", cfg.DNS.Domain) for _, f := range cfg.DNS.Filters { log.Println("sys", "dns", "Filter", f.Name, f.Allowance) r, err := regexp2.Compile(dns.Dnsname2Regexp(f.Name), 0) @@ -252,6 +253,18 @@ func LoadCfg(cfgs []byte) error { } Dns.AddFilter(r, f.Allowance) } + for _, r := range cfg.DNS.Records { + log.Println("sys", "dns", "Record", r.Domain, r.Type, r.Value) + Dns.AddRecord(regexp2.MustCompile(dns.Dnsname2Regexp(r.Domain), 0), dns.DnsStringTypeToInt(r.Type), r.Value) + } + for _, b := range cfg.DNS.Binds { + log.Println("sys", "dns", b.Name, "->", b.Addr) + err := Dns.AddRecordWithIP(b.Name, b.Addr) + if err != nil { + log.Println("sys", "dns", err) + os.Exit(-1) + } + } if cfg.DNS.Bind != "" { go func() {