-
Notifications
You must be signed in to change notification settings - Fork 12
/
tinydns.go
162 lines (149 loc) · 4.29 KB
/
tinydns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package tinydns
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"strings"
"github.com/miekg/dns"
"github.com/projectdiscovery/hmap/store/hybrid"
sliceutil "github.com/projectdiscovery/utils/slice"
)
// TinyDNS is a small DNS server that answers A queries from configured
// records, a local cache, or an upstream resolver (see ServeDNS).
type TinyDNS struct {
	// options is the configuration passed to New.
	options *Options
	// server is the underlying DNS server created in New and started by Run.
	server *dns.Server
	// hm stores gob-encoded records retrieved from upstream servers.
	hm *hybrid.HybridMap
	// OnServeDns, when non-nil, is called by ServeDNS with an Info value
	// describing each step taken while handling a query.
	OnServeDns func(data Info)
}
// Info describes one step of query handling; it is the value handed to the
// TinyDNS.OnServeDns callback.
type Info struct {
	// Domain is the queried name with the trailing dot removed.
	Domain string
	// Operation labels the step; ServeDNS sets it to "request", "in-memory",
	// "cached" or "saving".
	Operation string
	// Wildcard is true when the answer came from the "*" record.
	Wildcard bool
	// Msg is a human-readable, newline-terminated description of the step.
	Msg string
	// Upstream is the upstream server involved in the step, if any.
	Upstream string
}
// New creates a TinyDNS instance configured by options.
//
// It allocates the hybrid map used as the record cache and prepares (but does
// not start) a dns.Server that listens on options.ListenAddress over
// options.Net and uses the returned TinyDNS as its handler. Call Run to start
// serving.
//
// An error is returned when options is nil or when the cache cannot be
// created.
func New(options *Options) (*TinyDNS, error) {
	// Guard against a nil configuration: the fields below are dereferenced
	// here and on every request.
	if options == nil {
		return nil, fmt.Errorf("tinydns: options must not be nil")
	}

	hm, err := hybrid.New(hybrid.DefaultDiskOptions)
	if err != nil {
		return nil, fmt.Errorf("tinydns: creating cache: %w", err)
	}

	t := &TinyDNS{
		options: options,
		hm:      hm,
	}
	t.server = &dns.Server{
		Addr:    options.ListenAddress,
		Net:     options.Net,
		Handler: t,
	}
	return t, nil
}
// ServeDNS handles a single DNS query and writes exactly one response to w.
//
// For A queries the record is looked up in the following order:
//
//  1. an exact match in options.DnsRecords,
//  2. the "*" wildcard entry in options.DnsRecords,
//  3. the local cache,
//  4. a randomly picked server from options.UpstreamServers, whose A/AAAA
//     answers are then stored in the cache.
//
// Any query that is not answered by one of these steps (other query types,
// no matching source, undecodable cache entry, upstream failure) receives an
// empty reply. A message without a question section receives FORMERR.
//
// Each step is reported through OnServeDns when that callback is set.
// Errors from w.WriteMsg are deliberately ignored: a handler has no caller to
// report them to.
func (t *TinyDNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	// A message without a question cannot be resolved; indexing Question[0]
	// would panic, so reject it explicitly.
	if len(r.Question) == 0 {
		malformed := new(dns.Msg)
		malformed.SetRcode(r, dns.RcodeFormatError)
		_ = w.WriteMsg(malformed)
		return
	}

	question := r.Question[0]
	domain := question.Name
	domainlookup := strings.TrimSuffix(domain, ".")

	t.notifyServe(Info{
		Domain:    domainlookup,
		Operation: "request",
		Msg:       fmt.Sprintf("Received request for: %s\n", domainlookup),
	})

	// Stop as soon as a response has been written so the client never
	// receives a second, empty reply for the same query.
	if question.Qtype == dns.TypeA && t.serveA(w, r, domain, domainlookup) {
		return
	}

	_ = w.WriteMsg(reply(r, domain, &DnsRecord{}))
}

// serveA tries to answer an A query from the fallback chain described on
// ServeDNS. It reports whether a response was written to w.
func (t *TinyDNS) serveA(w dns.ResponseWriter, r *dns.Msg, domain, domainlookup string) bool {
	// Hardcoded record for this exact name.
	if dnsRecord, ok := t.options.DnsRecords[domainlookup]; ok {
		t.notifyServe(Info{
			Domain:    domainlookup,
			Operation: "in-memory",
			Msg:       fmt.Sprintf("Using in-memory record for %s.\n", domainlookup),
		})
		_ = w.WriteMsg(reply(r, domain, dnsRecord))
		return true
	}

	// Hardcoded wildcard record.
	if dnsRecord, ok := t.options.DnsRecords["*"]; ok {
		t.notifyServe(Info{
			Domain:    domainlookup,
			Operation: "in-memory",
			Wildcard:  true,
			Msg:       fmt.Sprintf("Using in-memory wildcard record for %s.\n", domainlookup),
		})
		_ = w.WriteMsg(reply(r, domain, dnsRecord))
		return true
	}

	// Cache; entries are keyed by the fully qualified name (with trailing dot).
	if dnsRecordBytes, ok := t.hm.Get(domain); ok {
		dnsRecord := &DnsRecord{}
		if err := gob.NewDecoder(bytes.NewReader(dnsRecordBytes)).Decode(dnsRecord); err != nil {
			// An undecodable entry is answered with the empty fallback reply;
			// upstream is not consulted in this case.
			return false
		}
		t.notifyServe(Info{
			Domain:    domainlookup,
			Operation: "cached",
			Msg:       fmt.Sprintf("Using cached record for %s.\n", domainlookup),
		})
		_ = w.WriteMsg(reply(r, domain, dnsRecord))
		return true
	}

	if len(t.options.UpstreamServers) == 0 {
		return false
	}

	// Forward to an upstream server and store the answer in the cache.
	upstreamServer := sliceutil.PickRandom(t.options.UpstreamServers)
	// NOTE(review): the operation label for the upstream lookup is "cached";
	// it is kept unchanged because OnServeDns consumers may match on it.
	t.notifyServe(Info{
		Domain:    domainlookup,
		Operation: "cached",
		Upstream:  upstreamServer,
		Msg:       fmt.Sprintf("Retrieving records for %s with upstream %s.\n", domainlookup, upstreamServer),
	})

	msg, err := dns.Exchange(r, upstreamServer)
	if err != nil {
		return false
	}
	_ = w.WriteMsg(msg)

	// Keep only the address records; other answer types are not cached.
	dnsRecord := &DnsRecord{}
	for _, record := range msg.Answer {
		switch rr := record.(type) {
		case *dns.A:
			dnsRecord.A = append(dnsRecord.A, rr.A.String())
		case *dns.AAAA:
			dnsRecord.AAAA = append(dnsRecord.AAAA, rr.AAAA.String())
		}
	}

	var encoded bytes.Buffer
	if err := gob.NewEncoder(&encoded).Encode(dnsRecord); err == nil {
		t.notifyServe(Info{
			Domain:    domainlookup,
			Operation: "saving",
			Upstream:  upstreamServer,
			Msg:       fmt.Sprintf("Saving records for %s in cache.\n", domainlookup),
		})
		// Caching is best-effort: the client has already been answered.
		_ = t.hm.Set(domain, encoded.Bytes())
	}
	return true
}

// notifyServe forwards info to the OnServeDns callback when one is set.
func (t *TinyDNS) notifyServe(info Info) {
	if t.OnServeDns != nil {
		t.OnServeDns(info)
	}
}
// reply builds an authoritative response to r containing one A record per
// entry of dnsRecord.A and one AAAA record per entry of dnsRecord.AAAA, all
// owned by domain with a TTL of 60 seconds.
//
// Entries that do not parse as an IP address (or, for A records, that are not
// IPv4 addresses) are skipped rather than emitted with empty or zeroed rdata.
func reply(r *dns.Msg, domain string, dnsRecord *DnsRecord) *dns.Msg {
	msg := dns.Msg{}
	msg.SetReply(r)
	msg.Authoritative = true

	for _, a := range dnsRecord.A {
		ip := net.ParseIP(a).To4()
		if ip == nil {
			continue
		}
		msg.Answer = append(msg.Answer, &dns.A{
			Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
			A:   ip,
		})
	}

	for _, aaaa := range dnsRecord.AAAA {
		ip := net.ParseIP(aaaa)
		if ip == nil {
			continue
		}
		msg.Answer = append(msg.Answer, &dns.AAAA{
			// The header type must match the record body; it was previously
			// set to TypeA, which mislabels the record on the wire.
			Hdr:  dns.RR_Header{Name: domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
			AAAA: ip,
		})
	}

	return &msg
}
// Run starts the DNS server created by New by calling ListenAndServe on it
// and returns that call's error.
func (t *TinyDNS) Run() error {
	return t.server.ListenAndServe()
}
// Close closes the hybrid map used as the record cache. It does not stop the
// DNS server, and any value returned by the map's Close is not propagated.
func (t *TinyDNS) Close() {
	t.hm.Close()
}