Skip to content

Commit

Permalink
log shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
paulgmiller committed Dec 22, 2023
1 parent e356ea7 commit 474a1a6
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 163 deletions.
30 changes: 5 additions & 25 deletions cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ Copyright © 2023 NAME HERE <EMAIL ADDRESS>
package cmd

import (
"encoding/json"
"log"
"net"

"github.com/paulgmiller/wg-sync/udpjoin"
"github.com/spf13/cobra"
)

Expand All @@ -31,22 +30,22 @@ func add(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}*/
jreq := joinRequest{
jreq := udpjoin.Request{
PublicKey: "DEADBEEFDEADBEEF", //d0.PublicKey.String(),
AuthToken: "TOTALLYSECRET",
}

resp, err := send(jreq)
resp, err := udpjoin.Send(joinServer, jreq)
if err != nil {
return err
}
log.Printf("got %s", resp.Assignedip)

jreq2 := joinRequest{
jreq2 := udpjoin.Request{
PublicKey: "amMRWDvsLUmNHn52xer2yl/UaAkXnDrd/HxUTRkEGXc=", //d0.PublicKey.String(),
AuthToken: "TOTALLYSECRET",
}
resp, err = send(jreq2)
resp, err = udpjoin.Send(joinServer, jreq2)
if err != nil {
return err
}
Expand All @@ -55,25 +54,6 @@ func add(cmd *cobra.Command, args []string) error {

}

func send(jReq joinRequest) (joinResponse, error) {

conn, err := net.Dial("udp", joinServer)
if err != nil {
return joinResponse{}, err
}
log.Printf("dialing %s, %s", joinServer, conn.LocalAddr().String())
defer conn.Close()
err = json.NewEncoder(conn).Encode(jReq)
if err != nil {
return joinResponse{}, err
}

var jResp joinResponse
err = json.NewDecoder(conn).Decode(&jResp)

return jResp, err
}

/* old and busted
resp, err := http.Get(cfgFile)
Expand Down
150 changes: 12 additions & 138 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@ package cmd

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"

"github.com/paulgmiller/wg-sync/nethelpers"
"github.com/paulgmiller/wg-sync/pretty"
"github.com/paulgmiller/wg-sync/token"
"github.com/paulgmiller/wg-sync/udpjoin"
"github.com/paulgmiller/wg-sync/wghelpers"
"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

const defaultJoinPort = ":5000"
Expand All @@ -38,13 +32,20 @@ func init() {
//probably have to pass in public ip and maye cidr?
}

type cidrAllocatorImpl struct{}

func (c cidrAllocatorImpl) Allocate() (net.IP, error) {
return net.ParseIP("10.0.0.100"), nil
}

func serve(cmd *cobra.Command, args []string) error {
mux := http.NewServeMux()
mux.HandleFunc("/peers", Peers)
mux.Handle("/token", token.New())

srv := http.Server{Addr: ":8888", Handler: mux}

//todo gracefully shut both servers down.
ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM)
defer stop()

Expand All @@ -55,10 +56,13 @@ func serve(cmd *cobra.Command, args []string) error {
}
}()

err := HaddleJoins(ctx, cidrAllocatorImpl{})
err := udpjoin.New().HaddleJoins(ctx, cidrAllocatorImpl{})
if err != nil {
log.Printf("udp handler exited with %s", err)
}
log.Printf("up and seving")
<-ctx.Done()
log.Printf("got term signal")

shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand All @@ -70,136 +74,6 @@ func serve(cmd *cobra.Command, args []string) error {
return err
}

type joinRequest struct {
PublicKey string
AuthToken string
}

type joinResponse struct {
Assignedip string
Peers []pretty.Peer
}

type cidrAllocator interface {
Allocate() (net.IP, error)
}

type cidrAllocatorImpl struct{}

func (c cidrAllocatorImpl) Allocate() (net.IP, error) {
return net.ParseIP("10.0.0.100"), nil
}

var lock sync.Mutex

func HaddleJoins(ctx context.Context, alloc cidrAllocator) error {
udpaddr, err := net.ResolveUDPAddr("udp", "127.0.0.1"+defaultJoinPort)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", udpaddr)
if err != nil {
return err
}
log.Printf("Waiting for joins on %s", udpaddr.String())
go func() {
for {
buf := make([]byte, 4096) //how big should we be? will we go over multiple packets?
n, remoteAddr, err := conn.ReadFromUDP(buf) //has to be this ratehr than desrialize because we need the remote addr or we get write: destination address required
if err != nil {
if !errors.Is(err, net.ErrClosed) {
log.Printf("Failed to read from udp: %s", err)
}
return
}
// Deserialize the JSON data into a Message struct
var jreq joinRequest
err = json.Unmarshal(buf[:n], &jreq)
if err != nil {
log.Printf("Failed to unmarshal: %s, %s", buf, err)

continue
}

//obviously bad.
if jreq.AuthToken != "HOKEYPOKEYSMOKEY" {
log.Printf("bad auth token from %v, %s", remoteAddr, jreq.PublicKey)
//ban them for a extended period?
continue
}

log.Printf("got join request from %v, %s", remoteAddr, jreq.PublicKey)
jResp, err := GenerateResponse(jreq, alloc)
if err != nil {
log.Printf("Failed to generate response %s", err)
//ban them for a extended period?
continue
}

respbuf, err := json.Marshal(jResp)
if err != nil {
log.Printf("Failed to enode: %s", err)
continue
}
_, err = conn.WriteToUDP(respbuf, remoteAddr)
if err != nil {
log.Printf("Failed to send: %s, %s", buf, err)
continue
}

}
}()
<-ctx.Done()
conn.Close()
log.Println("Listener closed")
return nil

}

func GenerateResponse(jreq joinRequest, alloc cidrAllocator) (joinResponse, error) {
lock.Lock()
defer lock.Unlock()

d0, err := wghelpers.GetDevice()
if err != nil {
return joinResponse{}, err
}

var asssignedip string
existing, found := lo.Find(d0.Peers, func(p wgtypes.Peer) bool { return p.PublicKey.String() == jreq.PublicKey })
if found { //should we also check that the ip is the same?
log.Printf("peer %s already exists", jreq.PublicKey)
asssignedip = existing.AllowedIPs[0].String()
} else {
ip, err := alloc.Allocate()
if err != nil {
//not nice to not tell them sorry? But then we need an error protocol
return joinResponse{}, err
}
asssignedip = ip.String()
}

//ad the peer to us before we return anything

cidr, err := nethelpers.GetWireGaurdCIDR(d0.Name)
if err != nil {
return joinResponse{}, err
}

//ip, cinet.ParseCIDR(cidr.String())

return joinResponse{
Assignedip: asssignedip,
Peers: []pretty.Peer{
{
PublicKey: d0.PublicKey.String(),
AllowedIPs: cidr.String(), //too much throttle down to /32?
Endpoint: fmt.Sprintf("%s:%d", nethelpers.GetOutboundIP(), d0.ListenPort), //just pass this in instead of trying to detect it?
},
},
}, nil
}

func Peers(resp http.ResponseWriter, req *http.Request) {
d0, err := wghelpers.GetDevice()
if err != nil {
Expand Down
Loading

0 comments on commit 474a1a6

Please sign in to comment.