From 47cca64114b9da855fb059db34aeda660d1cf57e Mon Sep 17 00:00:00 2001 From: Carl Kittelberger Date: Sun, 17 Mar 2024 05:51:44 +0100 Subject: [PATCH] Working state. Will probably squash or split this into proper commits later. --- cmd/storage-discover/main.go | 42 +++- cmd/storage/.gitattributes | 2 + .../Icedream - Whiplash (Radio Edit).m4a | 2 +- ...dream - Whiplash (Radio Edit).m4a.beatgrid | 3 + ...dream - Whiplash (Radio Edit).m4a.waveform | 3 + cmd/storage/demo.go | 61 ++++- cmd/storage/engine_library_service_server.go | 73 ++++-- cmd/storage/http.go | 40 +++- cmd/storage/main.go | 127 +++++------ cmd/storage/network_trust_service_server.go | 13 +- eaas/announcer.go | 158 ------------- eaas/announcer_configuration.go | 22 -- eaas/beacon.go | 212 ++++++++++++++++++ eaas/beacon_configuration.go | 35 +++ eaas/listener.go | 2 +- go.mod | 2 +- internal/socket/addr_ip.go | 18 ++ internal/socket/addr_net.go | 14 ++ internal/socket/addr_port_go1.18.go | 22 ++ internal/socket/addr_port_legacy.go | 24 ++ internal/socket/addr_port_test.go | 19 ++ internal/socket/broadcast.go | 26 +++ .../{util_test.go => broadcast_test.go} | 11 - internal/socket/reuse_port_nonwindows.go | 14 -- internal/socket/sockopts.go | 9 + internal/socket/sockopts_nonwindows.go | 23 ++ ...se_port_windows.go => sockopts_windows.go} | 0 internal/socket/util.go | 42 ---- 28 files changed, 650 insertions(+), 369 deletions(-) create mode 100644 cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid create mode 100644 cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform delete mode 100644 eaas/announcer.go delete mode 100644 eaas/announcer_configuration.go create mode 100644 eaas/beacon.go create mode 100644 eaas/beacon_configuration.go create mode 100644 internal/socket/addr_ip.go create mode 100644 internal/socket/addr_net.go create mode 100644 internal/socket/addr_port_go1.18.go create mode 100644 internal/socket/addr_port_legacy.go create mode 100644 internal/socket/addr_port_test.go rename internal/socket/{util_test.go => broadcast_test.go} (78%) delete mode 100644 internal/socket/reuse_port_nonwindows.go create mode 100644 internal/socket/sockopts.go create mode 100644 internal/socket/sockopts_nonwindows.go rename internal/socket/{reuse_port_windows.go => sockopts_windows.go} (100%) delete mode 100644 internal/socket/util.go diff --git a/cmd/storage-discover/main.go b/cmd/storage-discover/main.go index b3d7f4d..2cd6e4e 100644 --- a/cmd/storage-discover/main.go +++ b/cmd/storage-discover/main.go @@ -5,6 +5,8 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/x509" + "encoding/json" + "errors" "flag" "io" "log" @@ -107,6 +109,11 @@ func main() { runEngineLibraryUI(grpcURL) } +func marshalJSON(v any) []byte { + s, _ := json.Marshal(v) + return s +} + func runEngineLibraryUI(grpcURL string) { ctx := context.Background() connection, err := eaas.DialContext(ctx, grpcURL) @@ -139,23 +146,42 @@ func runEngineLibraryUI(grpcURL string) { if err != nil { panic(err) } - var pageSize uint32 = 100 + var pageSize uint32 = 25 + getTracksResp, err := connection.GetTracks(ctx, &enginelibrary.GetTracksRequest{ + PageSize: &pageSize, + }) + if err != nil { + panic(err) + } + for _, track := range getTracksResp.GetTracks() { + log.Printf("Track: %s", string(marshalJSON(track))) + getTrackResp, err := connection.GetTrack(ctx, &enginelibrary.GetTrackRequest{ + TrackId: track.GetMetadata().Id, + }) + if err != nil { + log.Println("\tfailed to GetTrack on this track") + continue + } + log.Printf("\t%+v", getTrackResp) + } for _, playlist := range getLibraryResp.GetPlaylists() { - log.Printf("Playlist %q (%q)", playlist.GetTitle(), playlist.GetListType()) - + log.Printf("Playlist: %s", string(marshalJSON(playlist))) getTracksResp, err := connection.GetTracks(ctx, &enginelibrary.GetTracksRequest{ PlaylistId: playlist.Id, PageSize: &pageSize, }) + if errors.Is(err, io.EOF) { + // BUG - empty playlist causes EOF, reconnect + connection, err = eaas.DialContext(ctx, grpcURL) + if err != nil { + panic(err) + } + } if err != nil { panic(err) } for _, track := range getTracksResp.GetTracks() { - metadata := track.GetMetadata() - if metadata == nil { - continue - } - log.Printf("\tTrack %s", metadata.String()) + log.Printf("\tTrack: ID %s", track.GetMetadata().GetId()) } } } diff --git a/cmd/storage/.gitattributes b/cmd/storage/.gitattributes index 4b0a82f..d427248 100644 --- a/cmd/storage/.gitattributes +++ b/cmd/storage/.gitattributes @@ -1 +1,3 @@ *.m4a filter=lfs diff=lfs merge=lfs -text +*.beatgrid filter=lfs diff=lfs merge=lfs -text +*.waveform filter=lfs diff=lfs merge=lfs -text diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a index 1b5dafe..f5ea7be 100644 --- a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5407491f594edd312074196d814f00cea0cd5737ad440b5ddd02f2b1511c8f81 +oid sha256:1ecca0471c006231ede7f8f2152104cee47597908e9f54c3d0298882bf95f6c5 size 4158124 diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid new file mode 100644 index 0000000..a77c832 --- /dev/null +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fba7a60b354a661ffcdc93679d5dd6e3acf700f63f0a3e051e1d2e89b43914f5 +size 59 diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform new file mode 100644 index 0000000..af2e6a5 --- /dev/null +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81342152f52d1707dbbd4d95b8ea9dd5108a1f4529e88e216230b529211b236b +size 2856 diff --git a/cmd/storage/demo.go b/cmd/storage/demo.go index 3b3604d..900682a 100644 --- a/cmd/storage/demo.go +++ b/cmd/storage/demo.go @@ -4,40 +4,81 @@ import ( "bytes" _ "embed" "fmt" + "path/filepath" "strconv" + "strings" "github.com/dhowden/tag" "github.com/google/uuid" + "github.com/icedream/go-stagelinq/eaas" "github.com/icedream/go-stagelinq/eaas/proto/enginelibrary" "google.golang.org/protobuf/types/known/timestamppb" ) +// Imports needed for image resizing (see commented out code for it) +// import ( +// "image" +// _ "image/jpeg" +// _ "image/png" +// ) + var ( - demoTrackFileName = "Icedream - Whiplash (Radio Edit).flac" + demoTrackFileName = "Icedream - Whiplash (Radio Edit).m4a" demoLibrary = "12eceaa2-f81a-4b63-b196-94648a3bdd95" demoLibraryName = "Demo Library" demoPlaylist = "55ab0c7c-6c35-429a-81d0-25b039a34a9f" demoPlaylistName = "Demo Playlist" demoPlaylistTrackCount uint32 = 1 - demoTrackIDs []string - demoTrackURL = "/demo/" + demoTrackFileName - demoTrackLength = uint32(len(demoTrackBytes)) - demoTrackMetadata enginelibrary.TrackMetadata - demoTrackArtwork []byte + demoTrackIDs = []string{ + "1 " + demoLibrary, + } + // HACK - imitating original Engine DJ software behavior by using Windows paths + demoTrackURL = filepath.Join("C:", "demo", demoTrackFileName) + // HACK - imitating original Engine DJ software behavior by adding brackets. + demoTrackURLGRPC = fmt.Sprintf("<%s>", filepath.ToSlash(demoTrackURL)) + demoTrackLength = uint32(len(demoTrackBytes)) + demoTrackMetadata enginelibrary.TrackMetadata + demoTrackArtwork []byte + demoToken eaas.Token = eaas.Token{ + 0x5e, 0xff, 0xae, 0x59, 0x12, 0x88, 0x29, 0x30, + 0xde, 0xad, 0xc0, 0xde, 0xc0, 0xff, 0xee, 0x00, + } ) //go:embed "Icedream - Whiplash (Radio Edit).m4a" var demoTrackBytes []byte +//go:embed "Icedream - Whiplash (Radio Edit).m4a.beatgrid" +var demoBeatGrid []byte + +//go:embed "Icedream - Whiplash (Radio Edit).m4a.waveform" +var demoOverviewWaveform []byte + +var demoTrackPreviewArtwork []byte + func init() { - for i := 0; i < int(demoPlaylistTrackCount); i++ { - demoTrackIDs = append(demoTrackIDs, uuid.New().String()) + if len(demoTrackIDs) == 0 { + for i := 0; i < int(demoPlaylistTrackCount); i++ { + demoTrackIDs = append(demoTrackIDs, uuid.New().String()) + } } demoTrackMetadata.DateAdded = timestamppb.Now() if metadata, err := tag.ReadFrom(bytes.NewReader(demoTrackBytes)); err == nil { if metadata.Picture() != nil { demoTrackArtwork = metadata.Picture().Data + demoTrackPreviewArtwork = demoTrackArtwork + // // If you wanna be nice to the hardware, you can have the server + // // shrink down the artwork. I don't think even the original Engine + // // DJ software does that though. + // img, _, err := image.Decode(bytes.NewReader(demoTrackArtwork)) + // if err == nil { + // img = resize.Resize(240, 240, img, resize.Lanczos2) + // } + // var b bytes.Buffer + // if err := jpeg.Encode(&b, img, &jpeg.Options{Quality: 70}); err == nil { + // demoTrackPreviewArtwork = b.Bytes() + // } } if v := metadata.Artist(); len(v) > 0 { demoTrackMetadata.Artist = &v @@ -64,19 +105,21 @@ func init() { } if v, ok := metadata.Raw()["KEY"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Key = &s } if v, ok := metadata.Raw()["LABEL"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Label = &s } if v, ok := metadata.Raw()["REMIXER"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Remixer = &s } if v := uint32(metadata.Year()); v > 0 { demoTrackMetadata.Year = &v } - } } diff --git a/cmd/storage/engine_library_service_server.go b/cmd/storage/engine_library_service_server.go index 11988ee..7323894 100644 --- a/cmd/storage/engine_library_service_server.go +++ b/cmd/storage/engine_library_service_server.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "fmt" + "log" "math" "strconv" @@ -15,31 +16,40 @@ import ( var _ enginelibrary.EngineLibraryServiceServer = &EngineLibraryServiceServer{} +// EngineLibraryServiceServer is an example library service server +// implementation. +// +// It will provide a the demo audio file as if contained in a library with +// playlists. Some functions not needed for the task are left unimplemented. type EngineLibraryServiceServer struct { enginelibrary.UnimplementedEngineLibraryServiceServer } // EventStream implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) EventStream(ctx context.Context, req *enginelibrary.EventStreamRequest) (*enginelibrary.EventStreamResponse, error) { + log.Printf("EventStream: %+v", req) return &enginelibrary.EventStreamResponse{ Event: []*enginelibrary.Event{}, }, nil } // GetCredentials implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetCredentials(context.Context, *enginelibrary.GetCredentialsRequest) (*enginelibrary.GetCredentialsResponse, error) { +func (e *EngineLibraryServiceServer) GetCredentials(ctx context.Context, req *enginelibrary.GetCredentialsRequest) (*enginelibrary.GetCredentialsResponse, error) { + log.Printf("GetCredentials: %+v", req) panic("unimplemented") } // GetHistoryPlayedTracks implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetHistoryPlayedTracks(context.Context, *enginelibrary.GetHistoryPlayedTracksRequest) (*enginelibrary.GetHistoryPlayedTracksResponse, error) { +func (e *EngineLibraryServiceServer) GetHistoryPlayedTracks(ctx context.Context, req *enginelibrary.GetHistoryPlayedTracksRequest) (*enginelibrary.GetHistoryPlayedTracksResponse, error) { + log.Printf("GetHistoryPlayedTracks: %+v", req) return &enginelibrary.GetHistoryPlayedTracksResponse{ Tracks: []*enginelibrary.HistoryPlayedTrack{}, }, nil } // GetHistorySessions implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetHistorySessions(context.Context, *enginelibrary.GetHistorySessionsRequest) (*enginelibrary.GetHistorySessionsResponse, error) { +func (e *EngineLibraryServiceServer) GetHistorySessions(ctx context.Context, req *enginelibrary.GetHistorySessionsRequest) (*enginelibrary.GetHistorySessionsResponse, error) { + log.Printf("GetHistorySessions: %+v", req) return &enginelibrary.GetHistorySessionsResponse{ Sessions: []*enginelibrary.HistorySession{}, }, nil @@ -47,6 +57,7 @@ func (e *EngineLibraryServiceServer) GetHistorySessions(context.Context, *engine // GetLibraries implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetLibraries(ctx context.Context, req *enginelibrary.GetLibrariesRequest) (*enginelibrary.GetLibrariesResponse, error) { + log.Printf("GetLibraries: %+v", req) return &enginelibrary.GetLibrariesResponse{ Libraries: []*enginelibrary.Library{ { @@ -59,8 +70,9 @@ func (e *EngineLibraryServiceServer) GetLibraries(ctx context.Context, req *engi // GetLibrary implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetLibrary(ctx context.Context, req *enginelibrary.GetLibraryRequest) (*enginelibrary.GetLibraryResponse, error) { + log.Printf("GetLibrary: %+v", req) switch req.GetLibraryId() { - case demoLibrary: + case "", demoLibrary: return &enginelibrary.GetLibraryResponse{ Playlists: []*enginelibrary.PlaylistMetadata{ { @@ -79,6 +91,7 @@ func (e *EngineLibraryServiceServer) GetLibrary(ctx context.Context, req *engine // GetSearchFilters implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetSearchFilters(ctx context.Context, req *enginelibrary.GetSearchFiltersRequest) (*enginelibrary.GetSearchFiltersResponse, error) { + log.Printf("GetSearchFilters: %+v", req) resp := &enginelibrary.GetSearchFiltersResponse{ SearchFilters: &enginelibrary.SearchFilterOptions{}, } @@ -143,26 +156,40 @@ func generateDemoTrackMetadata(trackID string) *enginelibrary.TrackMetadata { return &metadata } +var unsetFloat64 float64 = -1 + // GetTrack implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetTrack(ctx context.Context, req *enginelibrary.GetTrackRequest) (*enginelibrary.GetTrackResponse, error) { - if req.GetLibraryId() != demoLibrary && req.LibraryId != nil { + log.Printf("GetTrack: %+v", req) + if len(req.GetLibraryId()) != 0 && req.GetLibraryId() != demoLibrary { return nil, status.Error(codes.NotFound, "library not found") } for _, trackID := range demoTrackIDs { + metadata := generateDemoTrackMetadata(trackID) if trackID == req.GetTrackId() { - return &enginelibrary.GetTrackResponse{ + resp := &enginelibrary.GetTrackResponse{ Blob: &enginelibrary.TrackBlob{ Type: &enginelibrary.TrackBlob_Url{ Url: &enginelibrary.TrackBlobUrl{ - Url: &demoTrackURL, + Url: &demoTrackURLGRPC, FileSize: &demoTrackLength, }, }, }, - Metadata: generateDemoTrackMetadata(trackID), - PerformanceData: nil, // TODO - }, nil + Metadata: generateDemoTrackMetadata(trackID), + PerformanceData: &enginelibrary.TrackPerformanceData{ + Bpm: metadata.Bpm, + BeatGrid: demoBeatGrid, + MainCue: &enginelibrary.MainCue{ + Position: &unsetFloat64, + InitialPosition: &unsetFloat64, + }, + OverviewWaveform: demoOverviewWaveform, + }, + } + log.Printf("=> Found demo track ID: %+v", resp) + return resp, nil } } @@ -171,21 +198,31 @@ func (e *EngineLibraryServiceServer) GetTrack(ctx context.Context, req *engineli // GetTracks implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetTracks(ctx context.Context, req *enginelibrary.GetTracksRequest) (*enginelibrary.GetTracksResponse, error) { - switch req.GetLibraryId() { - case "", demoLibrary: + log.Printf("GetTracks: %+v", req) + switch { + case req.GetPlaylistId() == demoPlaylist: // specific playlist resp := &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, } for _, trackID := range demoTrackIDs { resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ Metadata: generateDemoTrackMetadata(trackID), - PreviewArtwork: demoTrackArtwork, + PreviewArtwork: demoTrackPreviewArtwork, }) } - return &enginelibrary.GetTracksResponse{ + return resp, nil + case req.GetLibraryId() == "" || req.GetLibraryId() == demoLibrary: // specific or default library + resp := &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, - }, nil - default: + } + for _, trackID := range demoTrackIDs { + resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ + Metadata: generateDemoTrackMetadata(trackID), + PreviewArtwork: demoTrackPreviewArtwork, + }) + } + return resp, nil + default: // neither playlist nor library match return &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, }, nil @@ -194,11 +231,13 @@ func (e *EngineLibraryServiceServer) GetTracks(ctx context.Context, req *enginel // PutEvents implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) PutEvents(ctx context.Context, req *enginelibrary.PutEventsRequest) (*enginelibrary.PutEventsResponse, error) { + log.Printf("PutEvents: %+v", req) return &enginelibrary.PutEventsResponse{}, nil } // SearchTracks implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) SearchTracks(ctx context.Context, req *enginelibrary.SearchTracksRequest) (*enginelibrary.SearchTracksResponse, error) { + log.Printf("SearchTracks: %+v", req) resp := &enginelibrary.SearchTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, } @@ -261,7 +300,7 @@ trackLoop: } resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ Metadata: metadata, - PreviewArtwork: demoTrackArtwork, + PreviewArtwork: demoTrackPreviewArtwork, }) } return resp, nil diff --git a/cmd/storage/http.go b/cmd/storage/http.go index 4f3bb45..02fcc3e 100644 --- a/cmd/storage/http.go +++ b/cmd/storage/http.go @@ -3,7 +3,9 @@ package main import ( "bytes" "io" + "log" "net/http" + "net/url" "strconv" "github.com/gorilla/mux" @@ -11,20 +13,46 @@ import ( func newHTTPServiceHandler() http.Handler { r := mux.NewRouter() - r.Get("/download/{path}").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Use( + func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Mux handling: %s %s", r.Method, r.URL.String()) + h.ServeHTTP(w, r) + }) + }, + ) + r.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Mux not found: %s %s", r.Method, r.URL.String()) + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Not found

Not found

")) + }) + r.UseEncodedPath() + r.SkipClean(true) + r.HandleFunc("/download/{path}", func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - requestedPath := vars["path"] - if requestedPath != demoTrackURL { + requestedPath, err := url.PathUnescape(vars["path"]) + if err != nil { + log.Println("HTTP: Download, bad path:", requestedPath) + w.Write([]byte("Badly encoded download path

Badly encoded download path

")) + w.WriteHeader(http.StatusBadRequest) + return + } + if requestedPath != demoTrackURLGRPC { + w.Write([]byte("Download not found

Download not found

")) + log.Println("HTTP: Download, not found:", requestedPath) w.WriteHeader(http.StatusNotFound) return } + log.Println("HTTP: Download, OK:", requestedPath) + w.Header().Set("Content-type", "application/octet-stream") w.Header().Set("Content-length", strconv.Itoa(len(demoTrackBytes))) w.WriteHeader(http.StatusOK) f := bytes.NewReader(demoTrackBytes) io.Copy(w, f) - }) - r.Get("/ping").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }).Methods(http.MethodGet) + r.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + log.Println("HTTP: Ping") w.WriteHeader(http.StatusOK) - }) + }).Methods(http.MethodGet) return r } diff --git a/cmd/storage/main.go b/cmd/storage/main.go index a12738e..a3461a2 100644 --- a/cmd/storage/main.go +++ b/cmd/storage/main.go @@ -1,14 +1,9 @@ package main import ( - "bytes" "context" "crypto/rand" - "encoding/binary" - "encoding/hex" - "errors" "fmt" - "io" "log" "net" "net/http" @@ -17,6 +12,7 @@ import ( "syscall" "time" + "github.com/icedream/go-stagelinq/eaas" "github.com/icedream/go-stagelinq/eaas/proto/enginelibrary" "github.com/icedream/go-stagelinq/eaas/proto/networktrust" "google.golang.org/grpc" @@ -28,25 +24,22 @@ const ( timeout = 5 * time.Second ) +var hostname string + +func init() { + var err error + hostname, err = os.Hostname() + if err != nil { + hostname = "eaas-demoserver" + } +} + func main() { var token [16]byte if _, err := rand.Read(token[:]); err != nil { panic(err) } - // listener, err := stagelinq.ListenWithConfiguration(&stagelinq.ListenerConfiguration{ - // DiscoveryTimeout: timeout, - // SoftwareName: appName, - // SoftwareVersion: appVersion, - // Name: "testing", - // }) - // if err != nil { - // panic(err) - // } - // defer listener.Close() - - // listener.AnnounceEvery(time.Second) - ctx := context.TODO() ctx, cancel := context.WithCancel(ctx) @@ -68,22 +61,22 @@ func main() { }() // set up grpc listener + grpcPort := eaas.DefaultEAASGRPCPort grpcListener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: net.IPv4(0, 0, 0, 0), - Port: 50010, + Port: int(grpcPort), }) if err != nil { panic(err) } enginelibrary.RegisterEngineLibraryServiceServer(grpcServer, &EngineLibraryServiceServer{}) - networktrust.RegisterNetworkTrustServiceServer(grpcServer, &NetworkTrustServer{}) + networktrust.RegisterNetworkTrustServiceServer(grpcServer, &NetworkTrustServiceServer{}) go func() { log.Println("Listening on GRPC") _ = grpcServer.Serve(grpcListener) }() // set up http listener - s.Addr = ":50020" + s.Addr = fmt.Sprintf(":%d", eaas.DefaultEAASHTTPPort) s.Handler = newHTTPServiceHandler() go func() { log.Println("Listening on HTTP") @@ -91,63 +84,51 @@ func main() { }() // listen for broadcasts - udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: net.IPv4(255, 255, 255, 255), - Port: 11224, + // interfaces, err := net.Interfaces() + // if err != nil { + // panic(err) + // } + // for _, intf := range interfaces { + // addrs, err := intf.Addrs() + // if err != nil { + // continue + // } + // for _, addr := range addrs { + // var ip net.IP + // switch v := addr.(type) { + // case *net.IPAddr: + // ip = v.IP + // case *net.IPNet: + // ip = v.IP + // default: + // continue + // } + // if ip.To4() == nil { + // continue + // } + // if ip.IsLinkLocalUnicast() { + // // skip zero-conf IPv4 and link-local IPv6 addresses requiring definition which interface to bind to + // continue + // } + // // ip := net.IPv4bcast + log.Println("Beacon starting") + beacon, err := eaas.StartBeaconWithConfiguration(&eaas.BeaconConfiguration{ + Name: hostname, + SoftwareVersion: appVersion, + GRPCPort: grpcPort, + Token: demoToken, }) if err != nil { panic(err) } - udpC := make(chan *net.UDPAddr, 2) - go func() { - b := make([]byte, 6) - for { - n, addr, err := udpListener.ReadFromUDP(b) - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { - return - } - if err != nil { - log.Println("UDP error, ignoring:", err) - continue - } - if n != 6 { - log.Println("UDP message too short, ignoring") - continue - } - if !bytes.Equal(b, eaasMagic) { - log.Println("UDP broadcast invalid, ignoring") - continue - } - udpC <- addr - } - }() - hostname, err := os.Hostname() - if err != nil { - hostname = "demo" - } - go func() { - log.Println("Listening on UDP") - for { - select { - case addr := <-udpC: - msg := new(bytes.Buffer) - msg.Write(eaasResponseMagic) - msg.Write(token[:]) - messages.WriteNetworkString(msg, hostname) - uri := fmt.Sprintf("grpc://%s:%d", "192.168.188.120", 50010) - binary.Write(msg, binary.BigEndian, uint32(len(uri))) - messages.WriteNetworkString(msg, appVersion) - msg.Write([]byte{0, 0, 0, 2, 0, 0x5f}) // TODO - b := msg.Bytes() - log.Println("Sending UDP beacon\n", hex.Dump(b)) - udpListener.WriteToUDP(b, addr) - case <-ctx.Done(): - _ = udpListener.Close() - return - } - } + defer func() { + log.Println("Beacon shutting down") + beacon.Shutdown() }() + // } + // } // wait for interrupt/term + log.Println("Running") <-ctx.Done() } diff --git a/cmd/storage/network_trust_service_server.go b/cmd/storage/network_trust_service_server.go index 60418c7..7e66274 100644 --- a/cmd/storage/network_trust_service_server.go +++ b/cmd/storage/network_trust_service_server.go @@ -2,21 +2,22 @@ package main import ( "context" - "time" + "log" "github.com/icedream/go-stagelinq/eaas/proto/networktrust" ) -var _ networktrust.NetworkTrustServiceServer = &NetworkTrustServer{} +var _ networktrust.NetworkTrustServiceServer = &NetworkTrustServiceServer{} -type NetworkTrustServer struct { +// NetworkTrustServiceServer is an example network trust service server +// implementation. All it does is approve everything that asks. +type NetworkTrustServiceServer struct { networktrust.UnimplementedNetworkTrustServiceServer } // CreateTrust implements networktrust.NetworkTrustServiceServer. -func (n *NetworkTrustServer) CreateTrust(ctx context.Context, _ *networktrust.CreateTrustRequest) (*networktrust.CreateTrustResponse, error) { - // safety sleep to not confuse Engine - time.After(time.Second) +func (n *NetworkTrustServiceServer) CreateTrust(ctx context.Context, req *networktrust.CreateTrustRequest) (*networktrust.CreateTrustResponse, error) { + log.Printf("CreateTrust: %+v", req) // Just allow all for now return &networktrust.CreateTrustResponse{ diff --git a/eaas/announcer.go b/eaas/announcer.go deleted file mode 100644 index 1890260..0000000 --- a/eaas/announcer.go +++ /dev/null @@ -1,158 +0,0 @@ -package eaas - -import ( - "bytes" - "context" - "errors" - "math/rand" - "net" - "sync" - "time" - - "github.com/icedream/go-stagelinq/internal/socket" -) - -// Announcer listens on UDP port 11224 for EAAS clients and announces itself to them. -type Announcer struct { - softwareVersion string - hostname string - packetConn net.PacketConn - token Token - port uint16 - shutdownCond *sync.Cond - shutdownWaitGroup sync.WaitGroup -} - -// Token returns our token that is being announced to the EAAS network. Use this -// token for further communication with services on other devices. -func (l *Announcer) Token() Token { - return l.token -} - -// Close shuts down the listener. -func (l *Announcer) Close() error { - // notify goroutines we are going to shut down and wait for them to finish - l.shutdownCond.Broadcast() - l.shutdownWaitGroup.Wait() - - return l.packetConn.Close() -} - -// Announce announces this EAAS listener to the network. -// -// This function should be called before actually listening in for devices to -// allow them to pick up our token for communication immediately. -func (l *Announcer) Announce() error { - return l.announce() -} - -// AnnounceEvery will start a goroutine which calls the Announce function at given interval. -// It will automatically terminate once this listener is shut down. -// A recommended value for the interval is 1 second. -func (l *Announcer) AnnounceEvery(interval time.Duration) { - shutdownC := make(chan interface{}, 1) - - // make Close() wait for us - l.shutdownWaitGroup.Add(1) - - // listen for shutdown signal broadcast, forward it to our own channel - go func() { - l.shutdownCond.L.Lock() - defer l.shutdownCond.L.Unlock() - l.shutdownCond.Wait() - shutdownC <- nil - }() - - go func() { - defer l.shutdownWaitGroup.Done() - - // timestamp for when to send next announcement - ticker := time.NewTicker(interval) - - // do first announcement immediately - l.Announce() - - for { - select { - case <-ticker.C: // next interval - announcement - if err := l.Announce(); errors.Is(err, net.ErrClosed) { - return - } - // NOTE - Considering AnnounceEvery is a fire-and-forget command we're ignoring other errors here for now. Not sure how to properly handle them otherwise atm. - case <-shutdownC: - return - } - } - }() -} - -func (l *Announcer) announce() (err error) { - // TODO - optimization: cache the built message because it will be sent repeatedly? - m := &eaasDiscoveryRequestMessage{} - b := new(bytes.Buffer) - err = m.WriteMessageTo(b) - if err != nil { - return - } - finalBytes := b.Bytes() - ips, err := socket.GetAllBroadcastIPs() - if err != nil { - return - } - for _, ip := range ips { - addr := makeEAASDiscoveryBroadcastAddress(ip) - packetConn, err := net.DialUDP("udp", nil, addr) - if err == nil { - _, _ = packetConn.Write(finalBytes) - packetConn.Close() - } - } - - return -} - -// Announce sets up a EAAS announcer. -func Announce() (announcer *Announcer, err error) { - return AnnounceWithConfiguration(nil) -} - -var zeroToken = Token{} - -// AnnounceWithConfiguration sets up a EAAS announcer with the given configuration. -func AnnounceWithConfiguration(announcerConfig *AnnouncerConfiguration) (announcer *Announcer, err error) { - // Use empty configuration if no configuration object was passed - if announcerConfig == nil { - announcerConfig = new(AnnouncerConfiguration) - } - - // Initialize token if none was configured - token := announcerConfig.Token - if bytes.Equal(announcerConfig.Token[:], zeroToken[:]) { - if _, err = rand.Read(token[:]); err != nil { - return - } - } - - // Use background context if none was configured - ctx := announcerConfig.Context - if ctx == nil { - ctx = context.Background() - } - - // We are setting up a shared UDP address socket here to allow other applications to still listen for EAAS discovery messages - config := &net.ListenConfig{ - Control: socket.SetSocketControlForReusePort, - } - packetConn, err := config.ListenPacket(ctx, eaasDiscoveryNetwork, eaasDiscoveryAddressString) - if err != nil { - return - } - - return &Announcer{ - hostname: announcerConfig.Name, - packetConn: packetConn, - softwareVersion: announcerConfig.SoftwareVersion, - token: token, - shutdownCond: sync.NewCond(&sync.Mutex{}), - }, nil -} diff --git a/eaas/announcer_configuration.go b/eaas/announcer_configuration.go deleted file mode 100644 index 29d6d9a..0000000 --- a/eaas/announcer_configuration.go +++ /dev/null @@ -1,22 +0,0 @@ -package eaas - -import ( - "context" -) - -// AnnouncerConfiguration contains configurable values for setting up a EAAS -// announcer. -type AnnouncerConfiguration struct { - // Context can be set to allow cancellation of network operations from somewhere else in the code. - Context context.Context - - // Name is the name under which we announce ourselves to the network. Denon - // software tends to use hostname here. - Name string - - // SoftwareVersion is your application's version. It is used for StagelinQ announcements to the network. - SoftwareVersion string - - // Token is used as part of announcements and main data communication. It is currently recommended to leave this empty. - Token Token -} diff --git a/eaas/beacon.go b/eaas/beacon.go new file mode 100644 index 0000000..2e8aa55 --- /dev/null +++ b/eaas/beacon.go @@ -0,0 +1,212 @@ +package eaas + +import ( + "bytes" + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "sync" + + "github.com/icedream/go-stagelinq/internal/messages" + "github.com/icedream/go-stagelinq/internal/socket" + "golang.org/x/net/ipv4" +) + +// Beacon listens on UDP port 11224 for EAAS clients and announces itself to them. +type Beacon struct { + softwareVersion string + hostname string + packetConn4 *ipv4.PacketConn + token Token + grpcHost string + grpcPort uint16 + shutdownWaitGroup sync.WaitGroup +} + +// Token returns our token that is being announced to the EAAS network. Use this +// token for further communication with services on other devices. +func (l *Beacon) Token() Token { + return l.token +} + +// Shutdown shuts down the listener. +func (l *Beacon) Shutdown() error { + err := l.packetConn4.Close() + + // wait for Listen goroutine to finish + l.shutdownWaitGroup.Wait() + + return err +} + +// List will start a goroutine which waits for EAAS clients to announce back to +// them. It will automatically terminate once this listener is shut down. +func (l *Beacon) listen() { + // make Close() wait for us + l.shutdownWaitGroup.Add(1) + + go func() { + defer l.shutdownWaitGroup.Done() + + b := make([]byte, 8) + for { + n, cm, addr, err := l.packetConn4.ReadFrom(b) + if errors.Is(err, net.ErrClosed) { + break + } + if err != nil { + // TODO - log this somehow + continue + } + if err = l.handleIncomingIPv4Packet(b[0:n], cm, addr); err != nil { + // TODO - log this somehow + continue + } + + } + }() +} + +func (l *Beacon) getIP() net.IP { + return socket.GetIPFromAddress(l.packetConn4.LocalAddr()) +} + +func (l *Beacon) getGRPCURL(ip net.IP) string { + var host string + switch { + case len(l.grpcHost) > 0: + host = l.grpcHost + default: + host = ip.String() + } + return fmt.Sprintf("grpc://%s:%d", host, l.grpcPort) +} + +func (l *Beacon) replyIPv4(cm *ipv4.ControlMessage, srcAddr net.Addr) error { + if cm == nil { + panic("control message must not be nil") + } + // figure out the an IP that we could be reachable from based on the + // interface the broadcast came from + intf, err := net.InterfaceByIndex(cm.IfIndex) + if err != nil { + return err + } + addrs, err := intf.Addrs() + if err != nil { + return err + } + // for _, addr := range addrs { + // ip := socket.GetIPFromAddress(addr) + // mask := socket.GetMaskFromAddress(addr) + // net := &net.IPNet{ + // IP: ip, + // Mask: mask, + // } + // net.Contains(ip net.IP) + // } + ip := socket.GetIPFromAddress(addrs[0]) + // TODO - optimization: cache the built message because it will be sent repeatedly? + m := &eaasDiscoveryResponseMessage{ + TokenPrefixedMessage: messages.TokenPrefixedMessage{ + Token: messages.Token(l.token), + }, + Hostname: l.hostname, + SoftwareVersion: l.softwareVersion, + URL: l.getGRPCURL(ip), + Extra: "_", + } + b := new(bytes.Buffer) + if err := m.WriteMessageTo(b); err != nil { + return err + } + ncm := &ipv4.ControlMessage{ + IfIndex: cm.IfIndex, + } + if _, err := l.packetConn4.WriteTo(b.Bytes(), ncm, srcAddr); err != nil { + return err + } + return nil +} + +func (l *Beacon) handleIncomingIPv4Packet(b []byte, cm *ipv4.ControlMessage, srcAddr net.Addr) error { + // decode message + r := bytes.NewReader(b) + m := new(eaasDiscoveryRequestMessage) + if err := m.ReadMessageFrom(r); err != nil { + return err + } + + return l.replyIPv4(cm, srcAddr) +} + +// StartBeacon sets up an EAAS beacon. +func StartBeacon() (*Beacon, error) { + return StartBeaconWithConfiguration(nil) +} + +var zeroToken = Token{} + +// StartBeaconWithConfiguration sets up a EAAS announcer with the given configuration. +func StartBeaconWithConfiguration(beaconConfig *BeaconConfiguration) (beacon *Beacon, err error) { + // Use empty configuration if no configuration object was passed + if beaconConfig == nil { + beaconConfig = new(BeaconConfiguration) + } + + // Initialize token if none was configured + token := beaconConfig.Token + if bytes.Equal(beaconConfig.Token[:], zeroToken[:]) { + if _, err = rand.Read(token[:]); err != nil { + return + } + } + + // Use default EAAS gRPC port if none was set + grpcPort := beaconConfig.GRPCPort + if grpcPort == 0 { + grpcPort = DefaultEAASGRPCPort + } + + // udpConn, err := net.ListenUDP("udp", makeEAASDiscoveryAddress(bindIP)) + // if err != nil { + // return + // } + + // We are setting up a shared UDP address socket here to allow other + // applications to still listen for EAAS discovery messages + config := &net.ListenConfig{ + Control: socket.SetSocketControlForReusePort, + } + packetConn, err := config.ListenPacket( + context.TODO(), + "udp4", + makeEAASDiscoveryAddress(net.IPv4zero).String()) + if err != nil { + return + } + + ipv4PacketConn := ipv4.NewPacketConn(packetConn) + if err := ipv4PacketConn.SetControlMessage(ipv4.FlagInterface, true); err != nil { + err = errors.Join(err, packetConn.Close()) + return nil, err + } + // if err := ipv4PacketConn.SetControlMessage(ipv4.FlagSrc, true); err != nil { + // err = errors.Join(err, packetConn.Close()) + // return nil, err + // } + + b := &Beacon{ + packetConn4: ipv4PacketConn, + hostname: beaconConfig.Name, + softwareVersion: beaconConfig.SoftwareVersion, + token: token, + grpcHost: beaconConfig.GRPCHost, + grpcPort: grpcPort, + } + go b.listen() + + return b, nil +} diff --git a/eaas/beacon_configuration.go b/eaas/beacon_configuration.go new file mode 100644 index 0000000..83976c6 --- /dev/null +++ b/eaas/beacon_configuration.go @@ -0,0 +1,35 @@ +package eaas + +// The default EAAS gRPC API port. +const DefaultEAASGRPCPort uint16 = 50010 + +// The default EAAS HTTP server port. +const DefaultEAASHTTPPort uint16 = DefaultEAASGRPCPort + 10 + +// BeaconConfiguration contains configurable values for setting up a EAAS +// announcer. +type BeaconConfiguration struct { + // Name is the name under which we announce ourselves to the network. Denon + // software tends to use hostname here. + Name string + + // SoftwareVersion is your application's version. It is used for StagelinQ + // announcements to the network. + SoftwareVersion string + + // Token is used as part of announcements and main data communication. It is + // currently recommended to leave this empty. + Token Token + + // The host to report back to clients with where the EAAS gRPC API is + // listening on. + // + // If left empty, defaults to the IP the beacon is bound to. It is + // recommended to leave this empty. + GRPCHost string + + // The port to report back to clients which the EAAS gRPC API is listening on. + // + // If left zero, defaults to the default EAAS gRPC API port (50010). + GRPCPort uint16 +} diff --git a/eaas/listener.go b/eaas/listener.go index 544b5ad..46686ba 100644 --- a/eaas/listener.go +++ b/eaas/listener.go @@ -32,7 +32,7 @@ const ( eaasDiscoveryAddressString = ":11224" ) -func makeEAASDiscoveryBroadcastAddress(ip net.IP) *net.UDPAddr { +func makeEAASDiscoveryAddress(ip net.IP) *net.UDPAddr { return &net.UDPAddr{ IP: ip, Port: 11224, diff --git a/go.mod b/go.mod index 747b021..baa3f61 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/lithammer/fuzzysearch v1.1.8 github.com/rivo/tview v0.0.0-20240307173318-e804876934a1 + golang.org/x/net v0.22.0 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 ) @@ -90,7 +91,6 @@ require ( golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect diff --git a/internal/socket/addr_ip.go b/internal/socket/addr_ip.go new file mode 100644 index 0000000..e07a0b4 --- /dev/null +++ b/internal/socket/addr_ip.go @@ -0,0 +1,18 @@ +package socket + +import ( + "net" +) + +func GetIPFromAddress(address net.Addr) net.IP { + switch convertedAddress := address.(type) { + case *net.UDPAddr: + return convertedAddress.IP + case *net.TCPAddr: + return convertedAddress.IP + case *net.IPNet: + return convertedAddress.IP + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_net.go b/internal/socket/addr_net.go new file mode 100644 index 0000000..20d4309 --- /dev/null +++ b/internal/socket/addr_net.go @@ -0,0 +1,14 @@ +package socket + +import ( + "net" +) + +func GetMaskFromAddress(address net.Addr) net.IPMask { + switch convertedAddress := address.(type) { + case *net.IPNet: + return convertedAddress.Mask + default: + return nil + } +} diff --git a/internal/socket/addr_port_go1.18.go b/internal/socket/addr_port_go1.18.go new file mode 100644 index 0000000..e1cc446 --- /dev/null +++ b/internal/socket/addr_port_go1.18.go @@ -0,0 +1,22 @@ +//go:build go1.18 +// +build go1.18 + +package socket + +import ( + "net" + "net/netip" +) + +type convertableToAddrPort interface { + AddrPort() netip.AddrPort +} + +func GetPortFromAddress(address net.Addr) uint16 { + switch convertedAddress := address.(type) { + case convertableToAddrPort: + return convertedAddress.AddrPort().Port() + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_port_legacy.go b/internal/socket/addr_port_legacy.go new file mode 100644 index 0000000..4d94da9 --- /dev/null +++ b/internal/socket/addr_port_legacy.go @@ -0,0 +1,24 @@ +//go:build !go1.18 +// +build !go1.18 + +package socket + +import ( + "net" + "net/netip" +) + +type convertableToAddrPort interface { + AddrPort() netip.AddrPort +} + +func GetPortFromAddress(address net.Addr) uint16 { + switch convertedAddress := address.(type) { + case *net.UDPAddr: + return uint16(convertedAddress.Port) + case *net.TCPAddr: + return uint16(convertedAddress.Port) + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_port_test.go b/internal/socket/addr_port_test.go new file mode 100644 index 0000000..16ddd5a --- /dev/null +++ b/internal/socket/addr_port_test.go @@ -0,0 +1,19 @@ +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_getPort(t *testing.T) { + require.Equal(t, 12345, GetPortFromAddress(&net.TCPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 12345, + })) + require.Equal(t, 12345, GetPortFromAddress(&net.UDPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 12345, + })) +} diff --git a/internal/socket/broadcast.go b/internal/socket/broadcast.go index 6f81ad3..371f751 100644 --- a/internal/socket/broadcast.go +++ b/internal/socket/broadcast.go @@ -39,3 +39,29 @@ addrsLoop: retval = ips return } + +func MakeBroadcastIP(ip net.IP, mask net.IPMask) (bip net.IP) { + // get 4-byte representation of ipv4 is possible, nil if not an ipv4 address + convertedIPv4 := false + if ip4 := ip.To4(); ip4 != nil { + convertedIPv4 = len(ip) != len(ip4) + ip = ip4 + } + + if len(mask) != len(ip) { + // mask and ip are different sizes, panic! + panic("net mask and ip address are different sizes") + } + + bip = make(net.IP, len(ip)) + for i := range mask { + bip[i] = ip[i] | ^mask[i] + } + + // convert back to 16-byte representation if input was 16-byte, too + if convertedIPv4 { + bip = bip.To16() + } + + return +} diff --git a/internal/socket/util_test.go b/internal/socket/broadcast_test.go similarity index 78% rename from internal/socket/util_test.go rename to internal/socket/broadcast_test.go index 457597e..f294e9d 100644 --- a/internal/socket/util_test.go +++ b/internal/socket/broadcast_test.go @@ -7,17 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func Test_getPort(t *testing.T) { - require.Equal(t, 12345, GetPortFromAddress(&net.TCPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 12345, - })) - require.Equal(t, 12345, GetPortFromAddress(&net.UDPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 12345, - })) -} - func Test_makeBroadcastIP(t *testing.T) { testValues := []struct { IP net.IP diff --git a/internal/socket/reuse_port_nonwindows.go b/internal/socket/reuse_port_nonwindows.go deleted file mode 100644 index 88f9043..0000000 --- a/internal/socket/reuse_port_nonwindows.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !windows -// +build !windows - -package socket - -import "syscall" - -func SetSocketControlForReusePort(_, _ string, c syscall.RawConn) error { - return c.Control(func(fd uintptr) { - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_DONTROUTE, 1) - }) -} diff --git a/internal/socket/sockopts.go b/internal/socket/sockopts.go new file mode 100644 index 0000000..694a91a --- /dev/null +++ b/internal/socket/sockopts.go @@ -0,0 +1,9 @@ +package socket + +import "syscall" + +// The function type needed to be provided to [net.Dialer.Control]. +type controlFunc func(network, address string, c syscall.RawConn) error + +// Sanity check +var _ controlFunc = SetSocketControlForReusePort diff --git a/internal/socket/sockopts_nonwindows.go b/internal/socket/sockopts_nonwindows.go new file mode 100644 index 0000000..7ec15c1 --- /dev/null +++ b/internal/socket/sockopts_nonwindows.go @@ -0,0 +1,23 @@ +//go:build !windows +// +build !windows + +package socket + +import ( + "log" + "syscall" +) + +func SetSocketControlForReusePort(_, _ string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + log.Println("Could not set sockopt SO_REUSEADDR:", err) + } + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1); err != nil { + log.Println("Could not set sockopt SO_BROADCAST:", err) + } + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_DONTROUTE, 1); err != nil { + log.Println("Could not set sockopt SO_DONTROUTE:", err) + } + }) +} diff --git a/internal/socket/reuse_port_windows.go b/internal/socket/sockopts_windows.go similarity index 100% rename from internal/socket/reuse_port_windows.go rename to internal/socket/sockopts_windows.go diff --git a/internal/socket/util.go b/internal/socket/util.go deleted file mode 100644 index 7430981..0000000 --- a/internal/socket/util.go +++ /dev/null @@ -1,42 +0,0 @@ -package socket - -import ( - "net" -) - -func GetPortFromAddress(address net.Addr) int { - switch convertedAddress := address.(type) { - case *net.UDPAddr: - return convertedAddress.Port - case *net.TCPAddr: - return convertedAddress.Port - default: - panic("unsupported network address type") - } -} - -func MakeBroadcastIP(ip net.IP, mask net.IPMask) (bip net.IP) { - // get 4-byte representation of ipv4 is possible, nil if not an ipv4 address - convertedIPv4 := false - if ip4 := ip.To4(); ip4 != nil { - convertedIPv4 = len(ip) != len(ip4) - ip = ip4 - } - - if len(mask) != len(ip) { - // mask and ip are different sizes, panic! - panic("net mask and ip address are different sizes") - } - - bip = make(net.IP, len(ip)) - for i := range mask { - bip[i] = ip[i] | ^mask[i] - } - - // convert back to 16-byte representation if input was 16-byte, too - if convertedIPv4 { - bip = bip.To16() - } - - return -}