From 6cf9c6a55107dcf985a7bac23b9d4ed60d93a913 Mon Sep 17 00:00:00 2001 From: Seth Davis Date: Thu, 23 Feb 2017 21:25:09 -0500 Subject: [PATCH] Prep for v1.0.0-beta.1 --- .gitignore | 37 +++ .travis.yml | 19 ++ CHANGES.md | 5 + LICENSE.txt | 21 ++ Makefile | 41 +++ README.md | 121 +++++++ appveyor.yml | 27 ++ dal.go | 272 ++++++++++++++++ dal_test.go | 189 +++++++++++ dns.go | 77 +++++ dns_test.go | 122 +++++++ http.go | 237 ++++++++++++++ http_test.go | 282 +++++++++++++++++ init/linux-systemd/nogo.service.example | 25 ++ main.go | 166 ++++++++++ main_test.go | 88 ++++++ web.go | 404 ++++++++++++++++++++++++ 17 files changed, 2133 insertions(+) create mode 100644 .gitignore create mode 100644 .travis.yml create mode 100644 CHANGES.md create mode 100644 LICENSE.txt create mode 100644 Makefile create mode 100644 README.md create mode 100644 appveyor.yml create mode 100644 dal.go create mode 100644 dal_test.go create mode 100644 dns.go create mode 100644 dns_test.go create mode 100644 http.go create mode 100644 http_test.go create mode 100644 init/linux-systemd/nogo.service.example create mode 100644 main.go create mode 100644 main_test.go create mode 100644 web.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f779b78 --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# external packages folder +vendor/ + +*.db +*.log +dist/ +dogo.json +nogo +TODO.txt \ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..52b5f38 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,19 @@ +language: go + +go: + - 1.8 + - tip + +matrix: + allow_failures: + - go: tip + fast_finish: true + +install: + - go get github.com/miekg/dns + - go get github.com/boltdb/bolt + - go get github.com/pressly/chi + +script: + - go vet + - go test -race \ No newline at end of file diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 0000000..3177759 --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,5 @@ +# Change log + +## [1.0.0-beta.1] - 2017-02-26 + +- Initial public release. \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..4bd42e6 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +Copyright (c) 2017 Seth Davis https://curia.solutions/ + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8eb8de5 --- /dev/null +++ b/Makefile @@ -0,0 +1,41 @@ +VERSION := 1.0.0-beta.1 +BINARY := nogo +SHELL := /bin/bash +LDFLAGS := "-X main.version=$(VERSION) -X main.build=`git rev-parse --verify --short HEAD`" +GOX_OSARCH := "darwin/amd64 freebsd/386 freebsd/amd64 freebsd/arm linux/386 linux/amd64 linux/arm64 netbsd/386 netbsd/amd64 netbsd/arm openbsd/386 openbsd/amd64 windows/386 windows/amd64" +GOX_OUTPUT := "build/$(BINARY)_v$(VERSION)_{{.OS}}_{{.Arch}}/$(BINARY)" + +.DEFAULT_GOAL := $(BINARY) +$(BINARY): + go build -ldflags $(LDFLAGS) -o $(BINARY) + +build: # https://github.com/mitchellh/gox + CGO_ENABLED=0 gox -ldflags $(LDFLAGS) -osarch $(GOX_OSARCH) -output $(GOX_OUTPUT) + # gox doesn't support fine-tuning arm (GOARM defaults to 6), so build linux/arm individually + CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=5 go build -ldflags $(LDFLAGS) -o build/$(BINARY)_v$(VERSION)_linux_armv5/$(BINARY) + CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -ldflags $(LDFLAGS) -o build/$(BINARY)_v$(VERSION)_linux_armv6/$(BINARY) + CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -ldflags $(LDFLAGS) -o build/$(BINARY)_v$(VERSION)_linux_armv7/$(BINARY) + +dist: build + @mkdir dist + $(eval BUILDS := $(shell find build/ -type f -printf '%P\n')) + @for f in $(BUILDS); do \ + echo "Archiving $${f%/*}..."; \ + if [[ $$f =~ darwin|windows ]]; then \ + zip -j dist/$${f%/*}.zip build/$$f README.md; \ + else \ + tar -cvzf dist/$${f%/*}.tar.gz README.md -C build/$${f%/*} $${f#*/}; \ + fi; \ + done + +.PHONY: deps +deps: + go get github.com/miekg/dns + go get github.com/boltdb/bolt + go get github.com/pressly/chi + +.PHONY: clean +clean: + rm -rf build/ + rm -rf dist/ + go clean \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..6fb35be --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +nogo +==== + +[![Linux Build Status](https://img.shields.io/travis/seedifferently/nogo.svg?style=flat-square&label=linux+build)](https://travis-ci.org/seedifferently/nogo) [![Windows Build Status](https://img.shields.io/appveyor/ci/seedifferently/nogo.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/seedifferently/nogo) + + +What? +----- + +`nogo` blocks access to various sites (ads, tracking, porn, gambling, etc) by +acting as a DNS proxy server with host blacklist support. + + +Why? +---- + +I wanted an open source ad blocker solution that was more universal than a +browser plugin, and: + +* Was easy to utilize with unrooted mobile devices (so that [battery life could + be conserved][1]). +* Had a basic web control panel and API for adding, removing, and "pausing" + hosts. +* Provided straightforward cross-platform support and acceptable performance (so + that I could run it from my Raspberry Pi). +* Could be used as a master host "blacklist" service for network-wide ad + blocking (e.g. by configuring the DNS on my router to point to `nogo`). + +[1]: https://lifehacker.com/ad-blockers-on-mobile-can-reduce-battery-drain-by-up-to-1764344384 + + +How? +---- + +You may simply [download a binary release](https://github.com/seedifferently/nogo/releases) +for your platform, or you can follow the steps below to build from source: + +1. Install [Go](https://golang.org/doc/install) (v1.8 or later is required). + +2. Clone the repo, then `cd` into it. + +3. Install the dependencies by running `make deps`. Or if you don't have `make`: + * `go get github.com/miekg/dns` + * `go get github.com/boltdb/bolt` + * `go get github.com/pressly/chi` + +4. Build the app by running `make`. Or if you don't have `make`: `go build` + +5. Run the app: `sudo ./nogo` + +**Note:** + +* Since `nogo` binds to port `:53` by default, it must be given access to + "privileged" ports (e.g. via `setuid` or `sudo`). +* Run with the `-help` switch for information on additional runtime options + (such as disabling or password protecting the web control panel). + + +### Important post-install steps: + +#### 1. You must add hosts to the blacklist. + +`nogo` doesn't ship with a built-in blacklist, so it won't block any hosts until +you add them. + +There are currently two methods for adding hosts to the blacklist: + +1. Navigate to the web control panel (default: [http://localhost:8080/][1]) and + add a host using the form. + +2. Download a popular hosts list file (e.g. pick one from the list at + [https://github.com/StevenBlack/hosts][2]), and execute `nogo` with the + `-import` switch on its first run. + + +#### 2. You must reconfigure your DNS. + +Your computer/mobile device/etc is probably set up by default to utilize a DNS +server which allows connections to any host. Unless you update your DNS +configuration to point to `nogo` (and *only* to `nogo`), nothing will change. + +For those of you who may be unfamiliar with how to update your DNS +configuration, check out Google's guide for their DNS service here: +[https://developers.google.com/speed/public-dns/docs/using][3] + +You can follow their instructions, but don't forget to substitute their DNS +service IP addresses with the sole IP address of the machine running `nogo`. + +[1]: http://localhost:8080/ +[2]: https://github.com/StevenBlack/hosts +[3]: https://developers.google.com/speed/public-dns/docs/using + + +### Known Issues and Limitations + +* The DNS proxy server utilizes a fairly basic configuration, so advanced + features such as EDNS and DNSSEC are not currently supported. +* Due to the fact that the web control panel utilizes a few modern techniques + (such as [flexbox][1] and the [Fetch API][2]), you may experience some issues + with its interface on non-current browsers. + +[1]: https://developer.mozilla.org/en-US/docs/Web/CSS/CSS_Flexible_Box_Layout/Using_CSS_flexible_boxes +[2]: https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API + + +Who? +---- + +My name is Seth and I've been talking to machines in various languages since the +early 90s. If you find this useful and want to say thanks, feel free to +[tweet me][1], [buy me a beer][2], [share some Satoshi][3], or pass +[my resume][4] on to someone you know who is tackling interesting problems with +software. + +[1]: https://twitter.com/seedifferently +[2]: https://paypal.me/seedifferently +[3]: https://coinbase.com/seedifferently +[4]: https://resume.sethdavis.name + + +Copyright (c) 2017 Seth Davis \ No newline at end of file diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 0000000..66c5dea --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,27 @@ +version: "{build}" + +os: Windows Server 2012 R2 + +clone_folder: c:\gopath\src\github.com\seedifferently\nogo + +environment: + GOPATH: c:\gopath + +install: + - rmdir c:\go /s /q + - appveyor DownloadFile https://storage.googleapis.com/golang/go1.8.windows-amd64.zip + - 7z x go1.8.windows-amd64.zip -y -oC:\ > NUL + - go version + - go env + - go get github.com/miekg/dns + - go get github.com/boltdb/bolt + - go get github.com/pressly/chi + - set PATH=%GOPATH%\bin;%PATH% + +build: off + +test_script: + - go vet + - go test -race + +deploy: off \ No newline at end of file diff --git a/dal.go b/dal.go new file mode 100644 index 0000000..5ce2973 --- /dev/null +++ b/dal.go @@ -0,0 +1,272 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "os" + "strings" + + "github.com/boltdb/bolt" +) + +var errRecordNotFound = errors.New("record not found") + +type Record struct { + Paused bool +} + +func (r *Record) isAllowed() bool { + // A paused record is an allowed record + return r.Paused +} + +func (r *Record) jsonEncode() ([]byte, error) { + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + return data, nil +} + +func (r *Record) jsonDecode(data []byte) (*Record, error) { + if err := json.Unmarshal(data, &r); err != nil { + return nil, err + } + + return r, nil +} + +func (db *DB) keyCount() (int, error) { + var stats bolt.BucketStats + + if err := db.View(func(tx *bolt.Tx) error { + // Get bucket stats + stats = tx.Bucket(blacklistKey).Stats() + + return nil + }); err != nil { + return 0, err + } + + return stats.KeyN, nil +} + +func (db *DB) get(key string) (*Record, error) { + var r *Record + + err := db.View(func(tx *bolt.Tx) error { + var err error + + v := tx.Bucket(blacklistKey).Get([]byte(strings.ToLower(key))) + if v == nil { + return errRecordNotFound + } else if len(v) == 0 { + // Empty value (likely due to hosts import) + r = &Record{} + return nil + } + + r, err = r.jsonDecode(v) + return err + }) + if err != nil { + return nil, err + } + + return r, nil +} + +func (db *DB) put(key string, r *Record) error { + err := db.Update(func(tx *bolt.Tx) error { + var err error + var v []byte + + // Check for record data + if r != nil { + v, err = r.jsonEncode() + if err != nil { + return err + } + } + + return tx.Bucket(blacklistKey).Put([]byte(strings.ToLower(key)), v) + }) + + return err +} + +func (db *DB) delete(key string) error { + err := db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(blacklistKey).Delete([]byte(strings.ToLower(key))) + }) + + return err +} + +func (db *DB) find(search string) map[string]*Record { + var err error + var recs = make(map[string]*Record) + + if search != "" { + search = strings.ToLower(search) + + db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(blacklistKey) + + b.ForEach(func(k, v []byte) error { + var r *Record + + if v == nil { + // Skip "sub-buckets" + return nil + } + + // Convert key to string for processing + sk := string(k) + + if strings.Contains(sk, search) { + if len(v) == 0 { + // Empty value (likely due to hosts import) + r = &Record{} + } else { + if r, err = r.jsonDecode(v); err != nil { + // Log the decode error and continue + log.Printf("Record.jsonDecode(%s) Error: %s\n", sk, err) + r = &Record{} + } + } + + recs[sk] = r + } + + return nil + }) + + return nil + }) + } + + return recs +} + +func (db *DB) getPaused() map[string]*Record { + var err error + var recs = make(map[string]*Record) + + db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(blacklistKey) + + b.ForEach(func(k, v []byte) error { + var r *Record + + if v == nil || len(v) == 0 { + // Skip "sub-buckets" and empty values + return nil + } + + if r, err = r.jsonDecode(v); err != nil { + // Log the decode error and continue + log.Printf("Record.jsonDecode(%s) Error: %s\n", k, err) + return nil + } + + if r.Paused { + recs[string(k)] = r + } + + return nil + }) + + return nil + }) + + return recs +} + +func (db *DB) importBlacklist(fname string) error { + f, err := os.Open(fname) + if err != nil { + return err + } + defer f.Close() + + lines, err := lineCount(f) + if err != nil { + return err + } + + // Rewind + if _, err = f.Seek(0, 0); err != nil { + return err + } + + n := 0 + scanner := bufio.NewScanner(f) + + for scanner.Scan() { + n++ + fmt.Printf("\rProcessing %d of %d", n, lines) + + // Parse line and trim any dots + r := strings.Trim(parseRecord(scanner.Text()), ".") + + // Ignore records that seem invalid, or don't contain dots (e.g. "localhost" and "test") + i := strings.IndexByte(r, '.') + if i < 1 || len(r) < 4 || i >= len(r)-2 { + continue + } + + db.put(r, nil) + } + + fmt.Print("\n") + + return scanner.Err() +} + +func parseRecord(s string) string { + // Ignore comments + i := strings.IndexByte(s, '#') + if i == 0 { + return "" + } else if i > 0 { + s = s[:i] + } + + sf := strings.Fields(s) + + if len(sf) < 1 { + // empty + return "" + } else if len(sf) > 1 { + // Return 2nd item if more than 1 + return sf[1] + } else { + // Return one and only item + return sf[0] + } +} + +func lineCount(r io.Reader) (int, error) { + buf := make([]byte, 32*1024) + sep := []byte{'\n'} + count := 0 + + for { + c, err := r.Read(buf) + count += bytes.Count(buf[:c], sep) + + switch { + case err == io.EOF: + return count, nil + case err != nil: + return count, err + } + } +} diff --git a/dal_test.go b/dal_test.go new file mode 100644 index 0000000..a82c13a --- /dev/null +++ b/dal_test.go @@ -0,0 +1,189 @@ +package main + +import ( + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/boltdb/bolt" +) + +func TestRecord_isAllowed(t *testing.T) { + r := &Record{} + testEqual(t, "isAllowed() = %+v, want %+v", r.isAllowed(), false) + + r = &Record{Paused: true} + testEqual(t, "isAllowed() = %+v, want %+v", r.isAllowed(), true) +} + +func TestRecord_jsonEncode(t *testing.T) { + r := &Record{} + d, _ := r.jsonEncode() + testEqual(t, "jsonEncode() = %+v, want %+v", string(d), "{\"Paused\":false}") + + r = &Record{Paused: true} + d, _ = r.jsonEncode() + testEqual(t, "jsonEncode() = %+v, want %+v", string(d), "{\"Paused\":true}") +} + +func TestRecord_jsonDecode(t *testing.T) { + var r *Record + + r, _ = r.jsonDecode([]byte("{\"Paused\":false}")) + testEqual(t, "jsonDecode() = %+v, want %+v", *r, Record{}) + + r, _ = r.jsonDecode([]byte("{\"Paused\":true}")) + testEqual(t, "jsonDecode() = %+v, want %+v", *r, Record{Paused: true}) +} + +func TestDB_keyCount(t *testing.T) { + var r *Record + db.Reset() + + c, _ := db.keyCount() + testEqual(t, "keyCount() = %+v, want %+v", c, 0) + + if err := db.put("test.test", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + c, _ = db.keyCount() + testEqual(t, "keyCount() = %+v, want %+v", c, 1) +} + +func TestDB_put_get(t *testing.T) { + var r *Record + db.Reset() + + if err := db.put("Nil.test", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + db.View(func(tx *bolt.Tx) error { + v := tx.Bucket(blacklistKey).Get([]byte("nil.test")) + testEqual(t, "put() = %+v, want %+v", string(v), "") + return nil + }) + r, _ = db.get("nil.Test") + testEqual(t, "get() = %+v, want %+v", *r, Record{}) + + if err := db.put("empty.test", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + db.View(func(tx *bolt.Tx) error { + v := tx.Bucket(blacklistKey).Get([]byte("empty.test")) + testEqual(t, "put() = %+v, want %+v", string(v), "{\"Paused\":false}") + return nil + }) + r, _ = db.get("empty.test") + testEqual(t, "get() = %+v, want %+v", *r, Record{}) + + if err := db.put("paused.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + db.View(func(tx *bolt.Tx) error { + v := tx.Bucket(blacklistKey).Get([]byte("paused.test")) + testEqual(t, "put() = %+v, want %+v", string(v), "{\"Paused\":true}") + return nil + }) + r, _ = db.get("paused.test") + testEqual(t, "get() = %+v, want %+v", *r, Record{Paused: true}) +} + +func TestDB_delete(t *testing.T) { + db.Reset() + + if err := db.put("delete.test", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + r, _ := db.get("delete.test") + testEqual(t, "get() = %+v, want %+v", *r, Record{}) + if err := db.delete("delete.test"); err != nil { + t.Errorf("failed to delete: %+v", err) + } + _, err := db.get("delete.test") + testEqual(t, "get() err = %+v, want %+v", err, errRecordNotFound) +} + +func TestDB_find(t *testing.T) { + var r *Record + db.Reset() + + if err := db.put("one.abcd", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("two.abcd", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("one.asdf", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("one.arst", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + + rs := db.find("one") + testEqual(t, "len(find('one')) = %+v, want %+v", len(rs), 3) + rs = db.find("abcd") + testEqual(t, "len(find('abcd')) = %+v, want %+v", len(rs), 2) + rs = db.find("arst") + testEqual(t, "len(find('arst')) = %+v, want %+v", len(rs), 1) +} + +func TestDB_getPaused(t *testing.T) { + var r *Record + db.Reset() + + if err := db.put("Nil.test", r); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("empty.test", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("paused.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + + rs := db.getPaused() + testEqual(t, "len(getPaused()) = %+v, want %+v", len(rs), 1) + testEqual(t, "getPaused()[0] = %+v, want %+v", *rs["paused.test"], Record{Paused: true}) +} + +func TestDB_importBlacklist(t *testing.T) { + db.Reset() + + f, err := ioutil.TempFile("", "nogo-import-") + if err != nil { + t.Errorf("failed to create TempFile: %+v", err) + } + f.WriteString(" # comment\n.\ntst\ntest\n.test\ntes.t\ntest.test\n") + f.Sync() + defer f.Close() + defer os.Remove(f.Name()) + + db.importBlacklist(f.Name()) + + c, err := db.keyCount() + if err != nil { + t.Errorf("failed to get keyCount: %+v", err) + } + testEqual(t, "keyCount() = %+v, want %+v", c, 1) + + r, _ := db.get("test.test") + testEqual(t, "get('test.test') = %+v, want %+v", *r, Record{}) +} + +func Test_parseRecord(t *testing.T) { + testEqual(t, "parseRecord('# comment') = %+v, want %+v", parseRecord("# comment"), "") + testEqual(t, "parseRecord(' ') = %+v, want %+v", parseRecord(" "), "") + testEqual(t, "parseRecord('partial # comment') = %+v, want %+v", parseRecord("partial # comment"), "partial") + testEqual(t, "parseRecord('127.0.0.1\tlocalhost # comment') = %+v, want %+v", parseRecord("127.0.0.1\tlocalhost # comment"), "localhost") + testEqual(t, "parseRecord('127.0.0.1 localhost alias # comment') = %+v, want %+v", parseRecord("127.0.0.1 localhost alias # comment"), "localhost") +} + +func TestDB_lineCount(t *testing.T) { + c, err := lineCount(strings.NewReader("one\ntwo\nthree\n")) + if err != nil { + t.Errorf("failed to get lineCount: %+v", err) + } + testEqual(t, "lineCount('one\\ntwo\\nthree\\n') = %+v, want %+v", c, 3) +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..98fb345 --- /dev/null +++ b/dns.go @@ -0,0 +1,77 @@ +package main + +import ( + "log" + "strings" + + "github.com/miekg/dns" +) + +func dnsHandler(w dns.ResponseWriter, r *dns.Msg) { + // Make a copy of the questions (in case we need them for the error response) + qs := make([]dns.Question, len(r.Question)) + copy(qs, r.Question) + + // If none of the questions are allowed, respond with an error message + if r.Question = filterQuestions(r.Question); len(r.Question) == 0 { + m := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: r.Id, + Response: true, + Opcode: dns.OpcodeQuery, + Authoritative: true, + RecursionDesired: r.RecursionDesired, + RecursionAvailable: false, + Rcode: dns.RcodeNameError, // NXDOMAIN + }, + Question: qs, + } + + w.WriteMsg(m) + return + } + + // Otherwise, proxy allowed questions upstream + for _, addr := range strings.Split(*dnsProxyTo, ",") { + in, _, err := dnsClient.Exchange(r, addr) + if err != nil { + log.Printf("Exchange(%s) Error: %s\n", addr, err) + continue + } + + w.WriteMsg(in) + return + } + + dns.HandleFailed(w, r) +} + +func filterQuestions(qs []dns.Question) []dns.Question { + var keep []dns.Question + + for _, q := range qs { + if isNameAllowed(q.Name) { + keep = append(keep, q) + } + } + + return keep +} + +func isNameAllowed(n string) bool { + n = strings.TrimSuffix(n, ".") + + r, err := db.get(n) + if err != nil { + if err == errRecordNotFound { + // If no record by that name was found, assume it is allowed + return true + } + + // For other errors, assume the name is now allowed + log.Printf("db.get(%s) Error: %s\n", n, err) + return false + } + + return r.isAllowed() +} diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 0000000..01d0d40 --- /dev/null +++ b/dns_test.go @@ -0,0 +1,122 @@ +package main + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/miekg/dns" +) + +func RunLocalDNSServer(laddr string, echo bool) (*dns.Server, string, error) { + pc, err := net.ListenPacket("udp", laddr) + if err != nil { + return nil, "", err + } + + server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + if echo { // Act as a simple echo server + server.Handler = dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + w.WriteMsg(m) + }) + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = func() { waitLock.Unlock() } + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + waitLock.Lock() + return server, pc.LocalAddr().String(), nil +} + +func Test_dnsHandler(t *testing.T) { + db.Reset() + + s, addrstr, err := RunLocalDNSServer("127.0.0.1:0", false) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + es, eaddrstr, err := RunLocalDNSServer("127.0.0.1:0", true) + if err != nil { + t.Fatalf("unable to run echo test server: %v", err) + } + defer es.Shutdown() + + *dnsProxyTo = eaddrstr + dns.HandleFunc(".", dnsHandler) + defer dns.HandleRemove(".") + + if err := db.put("test.disallowed", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("test.allowed", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + + m := new(dns.Msg) + m.SetQuestion("test.disallowed.", dns.TypeA) + + r, err := dns.Exchange(m, addrstr) + if err != nil { + t.Errorf("failed to exchange: %+v", err) + } + testEqual(t, "Disallowed Rcode = %+v, want %+v", r.Rcode, dns.RcodeNameError) + testEqual(t, "Disallowed Authoritative = %+v, want %+v", r.Authoritative, true) + testEqual(t, "Disallowed RecursionAvailable = %+v, want %+v", r.RecursionAvailable, false) + testEqual(t, "Disallowed Questions = %+v, want %+v", r.Question, m.Question) + + m = new(dns.Msg) + m.SetQuestion("test.allowed.", dns.TypeA) + r, err = dns.Exchange(m, addrstr) + if err != nil { + t.Errorf("failed to exchange: %+v", err) + } + testEqual(t, "Allowed Rcode = %+v, want %+v", r.Rcode, dns.RcodeSuccess) + testEqual(t, "Allowed Questions = %+v, want %+v", r.Question, m.Question) + + // Multiple questions + m = new(dns.Msg) + m.SetQuestion("test.allowed.", dns.TypeA) + m.Question = append(m.Question, dns.Question{Name: "test.disallowed.", Qtype: dns.TypeA, Qclass: dns.ClassINET}) + r, err = dns.Exchange(m, addrstr) + if err != nil { + t.Errorf("failed to exchange: %+v", err) + } + testEqual(t, "Multiple Rcode = %+v, want %+v", r.Rcode, dns.RcodeSuccess) + testEqual(t, "Multiple Response len() = %+v, want %+v", len(r.Question), 1) + testEqual(t, "Multiple Response Question = %+v, want %+v", r.Question[0].Name, "test.allowed.") +} + +func Test_filterQuestions(t *testing.T) { + db.Reset() + + if err := db.put("test.allowed", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + if err := db.put("test.disallowed", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + + qs := filterQuestions([]dns.Question{{Name: "Test.Allowed."}, {Name: "Test.disallowed"}}) + testEqual(t, "len(filterQuestions(...)) = %+v, want %+v", len(qs), 1) + testEqual(t, "filterQuestions(...)[0].Name = %+v, want %+v", qs[0].Name, "Test.Allowed.") +} + +func Test_isNameAllowed(t *testing.T) { + db.Reset() + + db.put("Test.allowed", &Record{Paused: true}) + db.put("Test.disallowed", &Record{}) + testEqual(t, "isNameAllowed('test.allowed') = %+v, want %+v", isNameAllowed("Test.Allowed."), true) + testEqual(t, "isNameAllowed('test.disallowed') = %+v, want %+v", isNameAllowed("Test.Disallowed"), false) + testEqual(t, "isNameAllowed('not.in.db') = %+v, want %+v", isNameAllowed("not.in.db."), true) +} diff --git a/http.go b/http.go new file mode 100644 index 0000000..80e77b1 --- /dev/null +++ b/http.go @@ -0,0 +1,237 @@ +package main + +import ( + "encoding/base64" + "html/template" + "io" + "log" + "net/http" + "strconv" + "strings" + + "github.com/pressly/chi" + "github.com/pressly/chi/render" +) + +type H map[string]interface{} + +// HTTP basic auth middleware +func basicAuth(password string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 { + http.Error(w, http.StatusText(401), 401) + return + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + log.Printf("base64.StdEncoding.DecodeString() Error: %s\n", err) + http.Error(w, http.StatusText(401), 401) + return + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + log.Printf("strings.SplitN() Error: %s\n", err) + http.Error(w, http.StatusText(401), 401) + return + } + + if pair[0] != "admin" || pair[1] != password { + http.Error(w, http.StatusText(401), 401) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// GET / (root index) +func rootIndexHandler(w http.ResponseWriter, r *http.Request) { + var data map[string]*Record + q := r.FormValue("q") + p := r.FormValue("p") + + if q != "" { // GET /?q=query + if len(q) >= 3 { + data = db.find(q) + } else { + http.Error(w, http.StatusText(422), 422) + return + } + } else if p == "1" { // GET /?p=1 + data = db.getPaused() + } + + totalCount, err := db.keyCount() + if err != nil { + log.Printf("db.keyCount() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + return + } + + tmpl, err := template.New("index").Parse(indexTmpl) + if err != nil { + log.Printf("template.ParseFiles() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + return + } + + if err = tmpl.Execute(w, H{"data": data, "total_count": totalCount, "q": q, "p": p}); err != nil { + log.Printf("tmpl.Execute() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + } +} + +// POST / +func rootCreateHandler(w http.ResponseWriter, r *http.Request) { + var rec *Record + + key := r.FormValue("key") + if len(key) < 4 { + http.Error(w, http.StatusText(422), 422) + return + } + + p := r.FormValue("paused") + if p == "1" { + rec = &Record{Paused: true} + } + + // Save + if err := db.put(key, rec); err != nil { + log.Printf("db.put(%s) Error: %s\n", key, err) + http.Error(w, http.StatusText(500), 500) + return + } + + // Redirect to key view + http.Redirect(w, r, strings.Join([]string{"/", key}, ""), 302) +} + +// GET /:key +func rootReadHandler(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + + rec, err := db.get(key) + if err == errRecordNotFound { + http.Error(w, http.StatusText(404), 404) + return + } else if err != nil { + log.Printf("db.get(%s) Error: %s\n", key, err) + http.Error(w, http.StatusText(500), 500) + return + } + data := map[string]*Record{key: rec} + + totalCount, err := db.keyCount() + if err != nil { + log.Printf("db.keyCount() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + return + } + + tmpl, err := template.New("index").Parse(indexTmpl) + if err != nil { + log.Printf("template.ParseFiles() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + return + } + + if err = tmpl.Execute(w, H{"data": data, "total_count": totalCount}); err != nil { + log.Printf("tmpl.Execute() Error: %s\n", err) + http.Error(w, http.StatusText(500), 500) + } +} + +// GET /api/ +func apiIndexHandler(w http.ResponseWriter, r *http.Request) { + var data map[string]*Record + + q := r.FormValue("q") + p := r.FormValue("p") + if len(q) >= 3 { // GET /api/?q=query + data = db.find(q) + } else if p == "1" { // GET /api/?p=1 + data = db.getPaused() + } else { + http.Error(w, http.StatusText(422), 422) + return + } + + render.JSON(w, r, H{"data": data}) +} + +// GET /api/:key +func apiReadHandler(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + + data, err := db.get(key) + if err == errRecordNotFound { + http.Error(w, http.StatusText(404), 404) + return + } else if err != nil { + log.Printf("db.get(%s) Error: %s\n", key, err) + http.Error(w, http.StatusText(500), 500) + return + } + + render.JSON(w, r, H{"data": H{key: data}}) +} + +// PUT /api/:key +func apiPutHandler(w http.ResponseWriter, r *http.Request) { + var data Record + + key := chi.URLParam(r, "key") + if len(key) < 4 { + http.Error(w, http.StatusText(422), 422) + return + } + + // Bind + if err := render.Bind(r.Body, &data); err != nil && err != io.EOF { + log.Printf("render.Bind() Error: %s\n", err) + http.Error(w, http.StatusText(400), 400) + return + } + + // Save + if err := db.put(key, &data); err != nil { + log.Printf("db.put(%s) Error: %s\n", key, err) + http.Error(w, http.StatusText(500), 500) + return + } + + render.JSON(w, r, H{"data": H{key: data}}) +} + +// DELETE /api/:key +func apiDeleteHandler(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + + // Delete + if err := db.delete(key); err != nil { + log.Printf("db.delete(%s) Error: %s\n", key, err) + http.Error(w, http.StatusText(500), 500) + return + } + + render.NoContent(w, r) +} + +// GET /css/nogo.css +func cssHandler(w http.ResponseWriter, r *http.Request) { + var data = []byte(nogoCSS) + + w.Header().Set("Cache-Control", "public, max-age=31536000") + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.Header().Set("Content-Type", "text/css") + + w.Write(data) +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..b5f7582 --- /dev/null +++ b/http_test.go @@ -0,0 +1,282 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/pressly/chi" +) + +func Test_basicAuth(t *testing.T) { + db.Reset() + + // Unauthorized (no Authorization header) + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + cr := chi.NewRouter() + cr.Use(basicAuth("test")) + cr.Get("/", rootIndexHandler) + cr.ServeHTTP(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 401) + + // Unauthorized (wrong password) + r = httptest.NewRequest("GET", "/", nil) + r.Header.Set("Authorization", "Basic YWRtaW46YWRtaW4=") + w = httptest.NewRecorder() + cr.ServeHTTP(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 401) + + // Authorized + r = httptest.NewRequest("GET", "/", nil) + r.Header.Set("Authorization", "Basic YWRtaW46dGVzdA==") + w = httptest.NewRecorder() + cr.ServeHTTP(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Body contains '0 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "0 total records"), true) +} + +func Test_rootIndexHandler(t *testing.T) { + db.Reset() + + // No records + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + rootIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Body contains '0 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "0 total records"), true) + + // A record + if err := db.put("test.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + r = httptest.NewRequest("GET", "/", nil) + w = httptest.NewRecorder() + rootIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "text/html; charset=utf-8") + testEqual(t, "Body contains '1 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "1 total records"), true) + + // Search record + r = httptest.NewRequest("GET", "/?q=te", nil) + w = httptest.NewRecorder() + rootIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 422) + r = httptest.NewRequest("GET", "/?q=tes", nil) + w = httptest.NewRecorder() + rootIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "text/html; charset=utf-8") + testEqual(t, "Body contains 'Found 1 of 1 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "Found 1 of 1 total records"), true) + testEqual(t, "Body contains 'test.test' = %+v, want %+v", strings.Contains(w.Body.String(), "
test.test
"), true) + + // List paused records + r = httptest.NewRequest("GET", "/?p=1", nil) + w = httptest.NewRecorder() + rootIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "text/html; charset=utf-8") + testEqual(t, "Body contains '1 of 1 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "1 of 1 total records"), true) + testEqual(t, "Body contains 'test.test' = %+v, want %+v", strings.Contains(w.Body.String(), "
test.test
"), true) +} + +func Test_rootCreateHandler(t *testing.T) { + db.Reset() + + // Invalid + r := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/"}, + Form: url.Values{"key": {"tst"}, "paused": {"1"}}, + } + w := httptest.NewRecorder() + rootCreateHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 422) + + // Valid + r = &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/"}, + Form: url.Values{"key": {"test.test"}, "paused": {"1"}}, + } + w = httptest.NewRecorder() + rootCreateHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 302) + testEqual(t, "Location header = %+v, want %+v", w.Header().Get("Location"), "/test.test") + // verify record created in db + rec, _ := db.get("test.test") + testEqual(t, "get() = %+v, want %+v", *rec, Record{Paused: true}) +} + +func Test_rootReadHandler(t *testing.T) { + db.Reset() + + // No records + r := httptest.NewRequest("GET", "/test.test", nil) + rctx := chi.NewRouteContext() + rctx.URLParams.Set("key", "test.test") + w := httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + rootReadHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 404) + + // A record + if err := db.put("test.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + r = httptest.NewRequest("GET", "/test.test", nil) + rctx = chi.NewRouteContext() + rctx.URLParams.Set("key", "test.test") + w = httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + rootReadHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "text/html; charset=utf-8") + testEqual(t, "Body contains '1 of 1 total records' = %+v, want %+v", strings.Contains(w.Body.String(), "1 of 1 total records"), true) + testEqual(t, "Body contains 'test.test' = %+v, want %+v", strings.Contains(w.Body.String(), "
test.test
"), true) +} + +func Test_apiIndexHandler(t *testing.T) { + db.Reset() + if err := db.put("test.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + + // Bare request + r := httptest.NewRequest("GET", "/api/", nil) + w := httptest.NewRecorder() + apiIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 422) + + // Search record + r = httptest.NewRequest("GET", "/api/?q=te", nil) + w = httptest.NewRecorder() + apiIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 422) + r = httptest.NewRequest("GET", "/api/?q=tes", nil) + w = httptest.NewRecorder() + apiIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "application/json; charset=utf-8") + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"test.test\":{\"Paused\":true}}}\n") + + // List paused records + r = httptest.NewRequest("GET", "/api/?p=1", nil) + w = httptest.NewRecorder() + apiIndexHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "application/json; charset=utf-8") + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"test.test\":{\"Paused\":true}}}\n") +} + +func Test_apiReadHandler(t *testing.T) { + db.Reset() + + // No records + r := httptest.NewRequest("GET", "/api/test.test", nil) + rctx := chi.NewRouteContext() + rctx.URLParams.Set("key", "test.test") + w := httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiReadHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 404) + + // A record + if err := db.put("test.test", &Record{Paused: true}); err != nil { + t.Errorf("failed to put: %+v", err) + } + r = httptest.NewRequest("GET", "/api/test.test", nil) + rctx = chi.NewRouteContext() + rctx.URLParams.Set("key", "test.test") + w = httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiReadHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "application/json; charset=utf-8") + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"test.test\":{\"Paused\":true}}}\n") +} + +func Test_apiPutHandler(t *testing.T) { + db.Reset() + + // Invalid + r := httptest.NewRequest("GET", "/api/tst", nil) + rctx := chi.NewRouteContext() + rctx.URLParams.Set("key", "tst") + w := httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiPutHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 422) + + // Create + r = httptest.NewRequest("GET", "/api/unpaused.test", nil) + rctx = chi.NewRouteContext() + rctx.URLParams.Set("key", "unpaused.test") + w = httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiPutHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "application/json; charset=utf-8") + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"unpaused.test\":{\"Paused\":false}}}\n") + // verify record created in db + rec, _ := db.get("unpaused.test") + testEqual(t, "get() = %+v, want %+v", *rec, Record{Paused: false}) + + // Create paused + r = httptest.NewRequest("GET", "/api/paused.test", strings.NewReader("{\"Paused\":true}")) + rctx = chi.NewRouteContext() + rctx.URLParams.Set("key", "paused.test") + w = httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiPutHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "application/json; charset=utf-8") + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"paused.test\":{\"Paused\":true}}}\n") + // verify record created in db + rec, _ = db.get("paused.test") + testEqual(t, "get() = %+v, want %+v", *rec, Record{Paused: true}) + + // Update + r = httptest.NewRequest("GET", "/api/unpaused.test", strings.NewReader("{\"Paused\":true}")) + rctx = chi.NewRouteContext() + rctx.URLParams.Set("key", "unpaused.test") + w = httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiPutHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Body = %+v, want %+v", w.Body.String(), "{\"data\":{\"unpaused.test\":{\"Paused\":true}}}\n") + // verify record updated in db + rec, _ = db.get("unpaused.test") + testEqual(t, "get() = %+v, want %+v", *rec, Record{Paused: true}) +} + +func Test_apiDeleteHandler(t *testing.T) { + db.Reset() + if err := db.put("test.test", &Record{}); err != nil { + t.Errorf("failed to put: %+v", err) + } + + r := httptest.NewRequest("DELETE", "/api/test.test", nil) + rctx := chi.NewRouteContext() + rctx.URLParams.Set("key", "test.test") + w := httptest.NewRecorder() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + apiDeleteHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 204) + // verify record deleted in db + _, err := db.get("test.test") + testEqual(t, "get() err = %+v, want %+v", err, errRecordNotFound) +} + +func Test_cssHandler(t *testing.T) { + r := httptest.NewRequest("GET", "/css/nogo.css", nil) + w := httptest.NewRecorder() + cssHandler(w, r) + testEqual(t, "Response code = %+v, want %+v", w.Code, 200) + testEqual(t, "Cache-Control header = %+v, want %+v", w.Header().Get("Cache-Control"), "public, max-age=31536000") + testEqual(t, "Content-Type header = %+v, want %+v", w.Header().Get("Content-Type"), "text/css") + testEqual(t, "Body contains 'Roboto' = %+v, want %+v", strings.Contains(w.Body.String(), "Roboto"), true) +} diff --git a/init/linux-systemd/nogo.service.example b/init/linux-systemd/nogo.service.example new file mode 100644 index 0000000..56c075a --- /dev/null +++ b/init/linux-systemd/nogo.service.example @@ -0,0 +1,25 @@ +[Unit] +Description=nogo DNS blacklist/proxy +Documentation=http://nogo.curia.solutions/ +After=network-online.target +Wants=network-online.target systemd-networkd-wait-online.service + +[Service] +Restart=on-failure +ExecStart=/usr/local/bin/nogo -db=/var/nogo/nogo.db +LimitNOFILE=32768 + +; Use private /tmp and /var/tmp +PrivateTmp=true +; Use a minimal /dev +PrivateDevices=true +; Hide /home, /root, and /run/user +ProtectHome=true +; Make /usr, /boot, /etc and possibly some more folders read-only +ProtectSystem=full +; … except /var/nogo, because we want our BoltDB file to live there. +; This merely retains r/w access rights, it does not add any new. Must still be writable on the host! +ReadWriteDirectories=/var/nogo + +[Install] +WantedBy=multi-user.target \ No newline at end of file diff --git a/main.go b/main.go new file mode 100644 index 0000000..718a78d --- /dev/null +++ b/main.go @@ -0,0 +1,166 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/boltdb/bolt" + "github.com/miekg/dns" + "github.com/pressly/chi" + "github.com/pressly/chi/middleware" +) + +type DB struct { + *bolt.DB +} + +var ( + db *DB + dnsServers []*dns.Server + httpServer *http.Server + dnsClient = &dns.Client{} + blacklistKey = []byte("blacklist") + version = "undefined" + build = "undefined" + + dbPath = flag.String("db", "nogo.db", "Specify a file path for the database.") + dnsAddr = flag.String("dns-addr", ":53", "Specify an address for the DNS proxy server to listen on.") + dnsNet = flag.String("dns-net", "udp", "Specify the listener protocol(s) for the DNS proxy server to use (\"udp\", \"tcp\", or \"udp+tcp\").") + dnsProxyTo = flag.String("dns-proxyto", "8.8.8.8:53,8.8.4.4:53", "Specify one or more (comma separated) upstream DNS server addresses to proxy allowed queries to.") + blacklist = flag.String("import", "", "Specify a file path to import records to block (traditional hosts file format, or simply one domain per line).") + webAddr = flag.String("web-addr", ":8080", "Specify an address for the control panel web server to listen on.") + webOff = flag.Bool("web-off", false, "Instruct nogo not to serve the web control panel/API.") + webPasswd = flag.String("web-password", "", "Instruct the web control panel/API to require basic auth, using the specified password and a username of \"admin\".") + showVer = flag.Bool("version", false, "Show version and exit.") +) + +func init() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "nogo version %s+%s %s/%s\n", version, build, runtime.GOOS, runtime.GOARCH) + fmt.Fprintln(os.Stderr, "Copyright (c) 2017 Seth Davis") + fmt.Fprintf(os.Stderr, "http://nogo.curia.solutions/\n\n") + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + } +} + +func main() { + var wg sync.WaitGroup + + flag.Parse() + + if *showVer { + fmt.Printf("nogo version %s+%s %s/%s\n", version, build, runtime.GOOS, runtime.GOARCH) + os.Exit(0) + } + + // Initialize the database + bdb, err := bolt.Open(*dbPath, 0600, &bolt.Options{Timeout: 2 * time.Second}) + if err != nil { + log.Fatalf("bolt.Open() Error: %s\n", err) + } + db = &DB{bdb} + defer db.Close() + + // Ensure blacklist bucket exists + if err = db.Update(func(tx *bolt.Tx) error { + // Get/create bucket + _, err := tx.CreateBucketIfNotExists(blacklistKey) + return err + }); err != nil { + log.Fatalf("CreateBucketIfNotExists(%s) Error: %s\n", blacklistKey, err) + } + + // Import a blacklist, if specified + if *blacklist != "" { + db.NoSync = true + + fmt.Println("Importing blacklist file. Please wait...") + if err := db.importBlacklist(*blacklist); err != nil { + log.Fatalf("db.importBlacklist(%s) Error: %s\n", *blacklist, err) + } + + if err := db.Sync(); err != nil { + log.Fatalf("db.Sync() Error: %s\n", err) + } + + db.NoSync = false + } + + // Initialize the HTTP router + r := chi.NewRouter() + + // Register HTTP middleware + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) + if *webPasswd != "" { + r.Use(basicAuth(*webPasswd)) + } + + // Register HTTP routes/handlers + r.Get("/", rootIndexHandler) + r.Post("/", rootCreateHandler) + r.Get("/:key", rootReadHandler) + r.Get("/api/", apiIndexHandler) + r.Get("/api/:key", apiReadHandler) + r.Put("/api/:key", apiPutHandler) + r.Delete("/api/:key", apiDeleteHandler) + r.Get("/css/nogo.css", cssHandler) + + // Initialize/start the servers + log.Println("Booting up nogo...") + + for _, n := range strings.Split(*dnsNet, "+") { + dnsServers = append(dnsServers, &dns.Server{Addr: *dnsAddr, Net: n, Handler: dns.HandlerFunc(dnsHandler)}) + + wg.Add(1) + go func(s *dns.Server) { + defer wg.Done() + if err := s.ListenAndServe(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + log.Fatal(err) + } + }(dnsServers[len(dnsServers)-1]) + } + log.Printf("DNS proxy listening at: %s (%s)\n", *dnsAddr, *dnsNet) + + if *webOff != true { + httpServer = &http.Server{Addr: *webAddr, Handler: r} + + wg.Add(1) + go func() { + defer wg.Done() + if err := httpServer.ListenAndServe(); err != http.ErrServerClosed { + log.Fatal(err) + } + }() + log.Printf("Web control panel/API listening at: %s\n", *webAddr) + } + + // Attempt to gracefully shut down when signaled + sig := make(chan os.Signal) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + s := <-sig + fmt.Printf("Signal (%s) received, shutting down... ", s) + + if *webOff != true { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + httpServer.Shutdown(ctx) + } + for _, s := range dnsServers { + s.Shutdown() + } + + wg.Wait() // Wait on goroutines + fmt.Println("Done!") +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..5da2f53 --- /dev/null +++ b/main_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "math/rand" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" + + "github.com/boltdb/bolt" +) + +func TestMain(m *testing.M) { + db = MustOpenDB() + exitVal := m.Run() + db.MustClose() + + os.Exit(exitVal) +} + +func randInt() int { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return r.Intn(10000) +} + +func tempfilePath(prefix string) string { + var name string + dir := os.TempDir() + conflict := true + + for i := 0; i < 10000; i++ { + name = filepath.Join(dir, prefix+strconv.Itoa(randInt())) + + if _, err := os.Stat(name); os.IsNotExist(err) { + conflict = false + break + } + } + + if conflict { + panic("couldn't find a suitable tempfile path") + } + + return name +} + +func MustOpenDB() *DB { + bdb, err := bolt.Open(tempfilePath("nogo-db-"), 0666, nil) + if err != nil { + panic(err) + } + + return &DB{bdb} +} + +func (db *DB) Reset() { + db.Update(func(tx *bolt.Tx) error { + // Delete bucket + tx.DeleteBucket(blacklistKey) + return nil + }) + + if err := db.Update(func(tx *bolt.Tx) error { + // Create bucket + _, err := tx.CreateBucket(blacklistKey) + return err + }); err != nil { + panic(err) + } +} + +func (db *DB) MustClose() { + defer os.Remove(db.Path()) + + if err := db.Close(); err != nil { + panic(err) + } +} + +func testEqual(t *testing.T, msg string, args ...interface{}) bool { + if !reflect.DeepEqual(args[len(args)-2], args[len(args)-1]) { + t.Errorf(msg, args...) + return false + } + return true +} diff --git a/web.go b/web.go new file mode 100644 index 0000000..f72aa40 --- /dev/null +++ b/web.go @@ -0,0 +1,404 @@ +package main + +// nogo.css +var nogoCSS = `*, +*:after, +*:before { + box-sizing: inherit; +} + +html { + box-sizing: border-box; + font-size: 62.5%; +} + +body { + color: #606c76; + font-family: 'Roboto', 'Helvetica Neue', 'Helvetica', 'Arial', sans-serif; + font-size: 1.6em; + font-weight: 300; + letter-spacing: .01em; + line-height: 1.6; +} + +input[type='search'], +input[type='text'] { + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; + background-color: transparent; + border: 0.1rem solid #d1d1d1; + border-radius: .4rem; + box-shadow: none; + box-sizing: inherit; + height: 3.8rem; + padding: .6rem 1.0rem; + width: 100%; +} + +input[type='search']:focus, +input[type='text']:focus { + border-color: #9b4dca; + outline: 0; +} + +label { + display: block; + font-size: 1.6rem; + font-weight: 700; + margin-bottom: .5rem; +} + +.container { + margin: 0 auto; + max-width: 112.0rem; + padding: 0 2.0rem; + position: relative; + width: 100%; +} + +.row { + display: flex; + flex-direction: column; + padding: 0; + width: 100%; +} + +.row .column { + display: block; + flex: 1 1 auto; + margin-left: 0; + max-width: 100%; + width: 100%; +} + +@media (min-width: 40rem) { + .row { + flex-direction: row; + margin-left: -1.0rem; + width: calc(100% + 2.0rem); + } + .row .column { + margin-bottom: inherit; + padding: 0 1.0rem; + } +} + +a { + color: #9b4dca; + text-decoration: none; +} + +a:focus, a:hover { + color: #606c76; +} + +fieldset, +input { + margin-bottom: 1.5rem; +} + +.text-right { text-align: right; } + +#header { + height: 69px; + background-color: #2f2f2f; + text-align: center; +} + +#header a { + font-weight: bold; + font-size: 40px; + letter-spacing: 2px; +} + +#main { + min-height: calc(100vh - 99px); + padding-bottom: 1rem; +} + +#inputs { padding-top: 15px; } + +#inputs form { margin-bottom: 5px; } + +#records-header > div { margin-bottom: 15px; } + +#records-header, +.row.record { + display: inline-flex; + flex-direction: row; + flex-wrap: nowrap; + white-space: nowrap; +} + +#back { flex: 0 0 0%; } + +.row.record:hover { background-color: #eee; } + +.column.actions { + flex: 0; + padding-right: 0; +} + +.column.key { padding-left: 0; } + +.actions form { + display: inline-block; + margin: 0; + padding: 0; +} + +.actions form.hide { display: none; } + +.actions button.icon { + vertical-align: text-top; + width: 16px; + height: 16px; + padding: 0; + margin: 0 10px 0 0; + border: none; + color: inherit; + background: none; + cursor: pointer; +} + +.actions button.icon:focus { + outline: 0; +} + +.actions button.icon-pause { + background: url() left bottom no-repeat; +} + +.actions button.icon-resume { + background: url() left bottom no-repeat; +} + +.actions button.icon-trash { + background: url() left bottom no-repeat; +} + +#footer { + height: 30px; + background-color: #2f2f2f; + padding-top: 3px; + text-align: center; + font-size: 15px; +} + +#footer a { + color: #fdfdfd; + font-weight: bold; + letter-spacing: 1px; +} + +@media (min-width: 400px) { + .actions button.icon { margin: 0 20px 0 0; } +} + +@media (min-width: 640px) { + .actions button.icon { margin: 0 25px 0 0; } +} + +@media (min-width: 960px) { + .column.actions { + order: 1; + padding: 0 1rem 0 0; + } + + .column.key { + overflow: hidden; + padding: 0 0 0 1rem; + } + + .actions button.icon { margin: 0 0 0 30px; } +} + +@media (min-width: 1122px) { + #header { border-radius: 0 0 5px 5px; } + + #footer { border-radius: 5px 5px 0 0; } +}` + +// index.html template string +var indexTmpl = ` + + + nogo + + + + + + + + + + + + + + + +
+
+
+
+ + +
+
+
+
+ + +
+
+
+ +
+ {{- if or .data .q .p }} +
+ « Back +
+ {{- else }} + + {{ end }} +
+ {{- if or .data .q .p }} + {{ if or .q }}Found {{ end }}{{ len .data }} of {{ .total_count }} total records. + {{- else }} + {{ .total_count }} total records. + {{- end }} +
+
+ + {{- range $k, $v := .data }} +
+
+ + + +
+ + + +
+
+
{{ $k }}
+
+ {{- end }} +
+ + + + + +`