diff --git a/cmd/storage-discover/key_uuid.go b/cmd/storage-discover/key_uuid.go new file mode 100644 index 0000000..d6cd15a --- /dev/null +++ b/cmd/storage-discover/key_uuid.go @@ -0,0 +1,38 @@ +package main + +import ( + "io" + "os" + + "github.com/google/uuid" +) + +var zeroUUID = uuid.UUID{} + +func loadUUIDKey() (string, error) { + var id uuid.UUID + if f, err := os.Open("eaas-id.txt"); err == nil { + defer f.Close() + keyBytes, err := io.ReadAll(f) + if err != nil { + return "", err + } + id, err = uuid.ParseBytes(keyBytes) + if err != nil { + return "", err + } + } + if id == zeroUUID { + var err error + id, err = uuid.NewUUID() + if err != nil { + return "", err + } + keyBytes, err := id.MarshalBinary() + if err != nil { + return "", err + } + os.WriteFile("eaas-id.txt", keyBytes, 0o600) + } + return id.String(), nil +} diff --git a/cmd/storage-discover/main.go b/cmd/storage-discover/main.go index b3d7f4d..aa9fd8c 100644 --- a/cmd/storage-discover/main.go +++ b/cmd/storage-discover/main.go @@ -2,16 +2,14 @@ package main import ( "context" - "crypto/ed25519" - "crypto/rand" - "crypto/x509" + "encoding/json" + "errors" "flag" "io" "log" "os" "time" - "github.com/google/uuid" "github.com/icedream/go-stagelinq/eaas" "github.com/icedream/go-stagelinq/eaas/proto/enginelibrary" "github.com/icedream/go-stagelinq/eaas/proto/networktrust" @@ -27,8 +25,7 @@ const ( var ( grpcURL string hostname string - key ed25519.PrivateKey - id uuid.UUID + identity string ) func init() { @@ -40,58 +37,6 @@ func init() { if err != nil { hostname = "eaas-demo" } - - if f, err := os.Open("eaas-key.bin"); err == nil { - defer f.Close() - keyBytes, err := io.ReadAll(f) - if err != nil { - panic(err) - } - readKey, err := x509.ParsePKCS8PrivateKey(keyBytes) - if err != nil { - panic(err) - } - if edkey, ok := readKey.(ed25519.PrivateKey); !ok { - panic("eaas-key.bin is not an ed25519 private key") - } else { - key = edkey - } - } - if key == nil { - _, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - keyBytes, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - panic(err) - } - os.WriteFile("eaas-key.bin", keyBytes, 0o600) - key = priv - } - - if f, err := os.Open("eaas-id.txt"); err == nil { - defer f.Close() - keyBytes, err := io.ReadAll(f) - if err != nil { - panic(err) - } - id, err = uuid.ParseBytes(keyBytes) - if err != nil { - panic(err) - } - } - if key == nil { - id, err = uuid.NewUUID() - if err != nil { - panic(err) - } - keyBytes, err := id.MarshalBinary() - if err != nil { - panic(err) - } - os.WriteFile("eaas-id.txt", keyBytes, 0o600) - } } type App struct { @@ -107,22 +52,33 @@ func main() { runEngineLibraryUI(grpcURL) } +func marshalJSON(v any) []byte { + s, _ := json.Marshal(v) + return s +} + func runEngineLibraryUI(grpcURL string) { + // load our identity so we don't have to repeatedly re-verify + identity, err := loadUUIDKey() + if err != nil { + log.Fatal(err) + } + ctx := context.Background() connection, err := eaas.DialContext(ctx, grpcURL) if err != nil { - panic(err) + log.Fatal(err) } - // pk := string(key.Public().(ed25519.PublicKey)) - pk := id.String() log.Println("Waiting for approval on the other end...") resp, err := connection.CreateTrust(ctx, &networktrust.CreateTrustRequest{ DeviceName: &hostname, - Ed25519Pk: &pk, + // I honestly don't know why in the proto it was defined as "Ed25519Pk"... + Ed25519Pk: &identity, + // ...or why there even is a WireguardPort field, too?! }) if err != nil { - panic(err) + log.Fatal(err) } switch { case resp.GetGranted() != nil: @@ -139,23 +95,42 @@ func runEngineLibraryUI(grpcURL string) { if err != nil { panic(err) } - var pageSize uint32 = 100 + var pageSize uint32 = 25 + getTracksResp, err := connection.GetTracks(ctx, &enginelibrary.GetTracksRequest{ + PageSize: &pageSize, + }) + if err != nil { + panic(err) + } + for _, track := range getTracksResp.GetTracks() { + log.Printf("Track: %s", string(marshalJSON(track))) + getTrackResp, err := connection.GetTrack(ctx, &enginelibrary.GetTrackRequest{ + TrackId: track.GetMetadata().Id, + }) + if err != nil { + log.Println("\tfailed to GetTrack on this track") + continue + } + log.Printf("\t%+v", getTrackResp) + } for _, playlist := range getLibraryResp.GetPlaylists() { - log.Printf("Playlist %q (%q)", playlist.GetTitle(), playlist.GetListType()) - + log.Printf("Playlist: %s", string(marshalJSON(playlist))) getTracksResp, err := connection.GetTracks(ctx, &enginelibrary.GetTracksRequest{ PlaylistId: playlist.Id, PageSize: &pageSize, }) + if errors.Is(err, io.EOF) { + // BUG - empty playlist causes EOF, reconnect + connection, err = eaas.DialContext(ctx, grpcURL) + if err != nil { + panic(err) + } + } if err != nil { panic(err) } for _, track := range getTracksResp.GetTracks() { - metadata := track.GetMetadata() - if metadata == nil { - continue - } - log.Printf("\tTrack %s", metadata.String()) + log.Printf("\tTrack: ID %s", track.GetMetadata().GetId()) } } } diff --git a/cmd/storage/.gitattributes b/cmd/storage/.gitattributes index 4b0a82f..d427248 100644 --- a/cmd/storage/.gitattributes +++ b/cmd/storage/.gitattributes @@ -1 +1,3 @@ *.m4a filter=lfs diff=lfs merge=lfs -text +*.beatgrid filter=lfs diff=lfs merge=lfs -text +*.waveform filter=lfs diff=lfs merge=lfs -text diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a index 1b5dafe..f5ea7be 100644 --- a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5407491f594edd312074196d814f00cea0cd5737ad440b5ddd02f2b1511c8f81 +oid sha256:1ecca0471c006231ede7f8f2152104cee47597908e9f54c3d0298882bf95f6c5 size 4158124 diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid new file mode 100644 index 0000000..a77c832 --- /dev/null +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.beatgrid @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fba7a60b354a661ffcdc93679d5dd6e3acf700f63f0a3e051e1d2e89b43914f5 +size 59 diff --git a/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform new file mode 100644 index 0000000..af2e6a5 --- /dev/null +++ b/cmd/storage/Icedream - Whiplash (Radio Edit).m4a.waveform @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81342152f52d1707dbbd4d95b8ea9dd5108a1f4529e88e216230b529211b236b +size 2856 diff --git a/cmd/storage/demo.go b/cmd/storage/demo.go index 3b3604d..900682a 100644 --- a/cmd/storage/demo.go +++ b/cmd/storage/demo.go @@ -4,40 +4,81 @@ import ( "bytes" _ "embed" "fmt" + "path/filepath" "strconv" + "strings" "github.com/dhowden/tag" "github.com/google/uuid" + "github.com/icedream/go-stagelinq/eaas" "github.com/icedream/go-stagelinq/eaas/proto/enginelibrary" "google.golang.org/protobuf/types/known/timestamppb" ) +// Imports needed for image resizing (see commented out code for it) +// import ( +// "image" +// _ "image/jpeg" +// _ "image/png" +// ) + var ( - demoTrackFileName = "Icedream - Whiplash (Radio Edit).flac" + demoTrackFileName = "Icedream - Whiplash (Radio Edit).m4a" demoLibrary = "12eceaa2-f81a-4b63-b196-94648a3bdd95" demoLibraryName = "Demo Library" demoPlaylist = "55ab0c7c-6c35-429a-81d0-25b039a34a9f" demoPlaylistName = "Demo Playlist" demoPlaylistTrackCount uint32 = 1 - demoTrackIDs []string - demoTrackURL = "/demo/" + demoTrackFileName - demoTrackLength = uint32(len(demoTrackBytes)) - demoTrackMetadata enginelibrary.TrackMetadata - demoTrackArtwork []byte + demoTrackIDs = []string{ + "1 " + demoLibrary, + } + // HACK - imitating original Engine DJ software behavior by using Windows paths + demoTrackURL = filepath.Join("C:", "demo", demoTrackFileName) + // HACK - imitating original Engine DJ software behavior by adding brackets. + demoTrackURLGRPC = fmt.Sprintf("<%s>", filepath.ToSlash(demoTrackURL)) + demoTrackLength = uint32(len(demoTrackBytes)) + demoTrackMetadata enginelibrary.TrackMetadata + demoTrackArtwork []byte + demoToken eaas.Token = eaas.Token{ + 0x5e, 0xff, 0xae, 0x59, 0x12, 0x88, 0x29, 0x30, + 0xde, 0xad, 0xc0, 0xde, 0xc0, 0xff, 0xee, 0x00, + } ) //go:embed "Icedream - Whiplash (Radio Edit).m4a" var demoTrackBytes []byte +//go:embed "Icedream - Whiplash (Radio Edit).m4a.beatgrid" +var demoBeatGrid []byte + +//go:embed "Icedream - Whiplash (Radio Edit).m4a.waveform" +var demoOverviewWaveform []byte + +var demoTrackPreviewArtwork []byte + func init() { - for i := 0; i < int(demoPlaylistTrackCount); i++ { - demoTrackIDs = append(demoTrackIDs, uuid.New().String()) + if len(demoTrackIDs) == 0 { + for i := 0; i < int(demoPlaylistTrackCount); i++ { + demoTrackIDs = append(demoTrackIDs, uuid.New().String()) + } } demoTrackMetadata.DateAdded = timestamppb.Now() if metadata, err := tag.ReadFrom(bytes.NewReader(demoTrackBytes)); err == nil { if metadata.Picture() != nil { demoTrackArtwork = metadata.Picture().Data + demoTrackPreviewArtwork = demoTrackArtwork + // // If you wanna be nice to the hardware, you can have the server + // // shrink down the artwork. I don't think even the original Engine + // // DJ software does that though. + // img, _, err := image.Decode(bytes.NewReader(demoTrackArtwork)) + // if err == nil { + // img = resize.Resize(240, 240, img, resize.Lanczos2) + // } + // var b bytes.Buffer + // if err := jpeg.Encode(&b, img, &jpeg.Options{Quality: 70}); err == nil { + // demoTrackPreviewArtwork = b.Bytes() + // } } if v := metadata.Artist(); len(v) > 0 { demoTrackMetadata.Artist = &v @@ -64,19 +105,21 @@ func init() { } if v, ok := metadata.Raw()["KEY"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Key = &s } if v, ok := metadata.Raw()["LABEL"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Label = &s } if v, ok := metadata.Raw()["REMIXER"]; ok { s := fmt.Sprint(v) + s = strings.Trim(s, "\x00 ") demoTrackMetadata.Remixer = &s } if v := uint32(metadata.Year()); v > 0 { demoTrackMetadata.Year = &v } - } } diff --git a/cmd/storage/engine_library_service_server.go b/cmd/storage/engine_library_service_server.go index 11988ee..7323894 100644 --- a/cmd/storage/engine_library_service_server.go +++ b/cmd/storage/engine_library_service_server.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "fmt" + "log" "math" "strconv" @@ -15,31 +16,40 @@ import ( var _ enginelibrary.EngineLibraryServiceServer = &EngineLibraryServiceServer{} +// EngineLibraryServiceServer is an example library service server +// implementation. +// +// It will provide a the demo audio file as if contained in a library with +// playlists. Some functions not needed for the task are left unimplemented. type EngineLibraryServiceServer struct { enginelibrary.UnimplementedEngineLibraryServiceServer } // EventStream implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) EventStream(ctx context.Context, req *enginelibrary.EventStreamRequest) (*enginelibrary.EventStreamResponse, error) { + log.Printf("EventStream: %+v", req) return &enginelibrary.EventStreamResponse{ Event: []*enginelibrary.Event{}, }, nil } // GetCredentials implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetCredentials(context.Context, *enginelibrary.GetCredentialsRequest) (*enginelibrary.GetCredentialsResponse, error) { +func (e *EngineLibraryServiceServer) GetCredentials(ctx context.Context, req *enginelibrary.GetCredentialsRequest) (*enginelibrary.GetCredentialsResponse, error) { + log.Printf("GetCredentials: %+v", req) panic("unimplemented") } // GetHistoryPlayedTracks implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetHistoryPlayedTracks(context.Context, *enginelibrary.GetHistoryPlayedTracksRequest) (*enginelibrary.GetHistoryPlayedTracksResponse, error) { +func (e *EngineLibraryServiceServer) GetHistoryPlayedTracks(ctx context.Context, req *enginelibrary.GetHistoryPlayedTracksRequest) (*enginelibrary.GetHistoryPlayedTracksResponse, error) { + log.Printf("GetHistoryPlayedTracks: %+v", req) return &enginelibrary.GetHistoryPlayedTracksResponse{ Tracks: []*enginelibrary.HistoryPlayedTrack{}, }, nil } // GetHistorySessions implements enginelibrary.EngineLibraryServiceServer. -func (e *EngineLibraryServiceServer) GetHistorySessions(context.Context, *enginelibrary.GetHistorySessionsRequest) (*enginelibrary.GetHistorySessionsResponse, error) { +func (e *EngineLibraryServiceServer) GetHistorySessions(ctx context.Context, req *enginelibrary.GetHistorySessionsRequest) (*enginelibrary.GetHistorySessionsResponse, error) { + log.Printf("GetHistorySessions: %+v", req) return &enginelibrary.GetHistorySessionsResponse{ Sessions: []*enginelibrary.HistorySession{}, }, nil @@ -47,6 +57,7 @@ func (e *EngineLibraryServiceServer) GetHistorySessions(context.Context, *engine // GetLibraries implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetLibraries(ctx context.Context, req *enginelibrary.GetLibrariesRequest) (*enginelibrary.GetLibrariesResponse, error) { + log.Printf("GetLibraries: %+v", req) return &enginelibrary.GetLibrariesResponse{ Libraries: []*enginelibrary.Library{ { @@ -59,8 +70,9 @@ func (e *EngineLibraryServiceServer) GetLibraries(ctx context.Context, req *engi // GetLibrary implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetLibrary(ctx context.Context, req *enginelibrary.GetLibraryRequest) (*enginelibrary.GetLibraryResponse, error) { + log.Printf("GetLibrary: %+v", req) switch req.GetLibraryId() { - case demoLibrary: + case "", demoLibrary: return &enginelibrary.GetLibraryResponse{ Playlists: []*enginelibrary.PlaylistMetadata{ { @@ -79,6 +91,7 @@ func (e *EngineLibraryServiceServer) GetLibrary(ctx context.Context, req *engine // GetSearchFilters implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetSearchFilters(ctx context.Context, req *enginelibrary.GetSearchFiltersRequest) (*enginelibrary.GetSearchFiltersResponse, error) { + log.Printf("GetSearchFilters: %+v", req) resp := &enginelibrary.GetSearchFiltersResponse{ SearchFilters: &enginelibrary.SearchFilterOptions{}, } @@ -143,26 +156,40 @@ func generateDemoTrackMetadata(trackID string) *enginelibrary.TrackMetadata { return &metadata } +var unsetFloat64 float64 = -1 + // GetTrack implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetTrack(ctx context.Context, req *enginelibrary.GetTrackRequest) (*enginelibrary.GetTrackResponse, error) { - if req.GetLibraryId() != demoLibrary && req.LibraryId != nil { + log.Printf("GetTrack: %+v", req) + if len(req.GetLibraryId()) != 0 && req.GetLibraryId() != demoLibrary { return nil, status.Error(codes.NotFound, "library not found") } for _, trackID := range demoTrackIDs { + metadata := generateDemoTrackMetadata(trackID) if trackID == req.GetTrackId() { - return &enginelibrary.GetTrackResponse{ + resp := &enginelibrary.GetTrackResponse{ Blob: &enginelibrary.TrackBlob{ Type: &enginelibrary.TrackBlob_Url{ Url: &enginelibrary.TrackBlobUrl{ - Url: &demoTrackURL, + Url: &demoTrackURLGRPC, FileSize: &demoTrackLength, }, }, }, - Metadata: generateDemoTrackMetadata(trackID), - PerformanceData: nil, // TODO - }, nil + Metadata: generateDemoTrackMetadata(trackID), + PerformanceData: &enginelibrary.TrackPerformanceData{ + Bpm: metadata.Bpm, + BeatGrid: demoBeatGrid, + MainCue: &enginelibrary.MainCue{ + Position: &unsetFloat64, + InitialPosition: &unsetFloat64, + }, + OverviewWaveform: demoOverviewWaveform, + }, + } + log.Printf("=> Found demo track ID: %+v", resp) + return resp, nil } } @@ -171,21 +198,31 @@ func (e *EngineLibraryServiceServer) GetTrack(ctx context.Context, req *engineli // GetTracks implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) GetTracks(ctx context.Context, req *enginelibrary.GetTracksRequest) (*enginelibrary.GetTracksResponse, error) { - switch req.GetLibraryId() { - case "", demoLibrary: + log.Printf("GetTracks: %+v", req) + switch { + case req.GetPlaylistId() == demoPlaylist: // specific playlist resp := &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, } for _, trackID := range demoTrackIDs { resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ Metadata: generateDemoTrackMetadata(trackID), - PreviewArtwork: demoTrackArtwork, + PreviewArtwork: demoTrackPreviewArtwork, }) } - return &enginelibrary.GetTracksResponse{ + return resp, nil + case req.GetLibraryId() == "" || req.GetLibraryId() == demoLibrary: // specific or default library + resp := &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, - }, nil - default: + } + for _, trackID := range demoTrackIDs { + resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ + Metadata: generateDemoTrackMetadata(trackID), + PreviewArtwork: demoTrackPreviewArtwork, + }) + } + return resp, nil + default: // neither playlist nor library match return &enginelibrary.GetTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, }, nil @@ -194,11 +231,13 @@ func (e *EngineLibraryServiceServer) GetTracks(ctx context.Context, req *enginel // PutEvents implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) PutEvents(ctx context.Context, req *enginelibrary.PutEventsRequest) (*enginelibrary.PutEventsResponse, error) { + log.Printf("PutEvents: %+v", req) return &enginelibrary.PutEventsResponse{}, nil } // SearchTracks implements enginelibrary.EngineLibraryServiceServer. func (e *EngineLibraryServiceServer) SearchTracks(ctx context.Context, req *enginelibrary.SearchTracksRequest) (*enginelibrary.SearchTracksResponse, error) { + log.Printf("SearchTracks: %+v", req) resp := &enginelibrary.SearchTracksResponse{ Tracks: []*enginelibrary.ListTrack{}, } @@ -261,7 +300,7 @@ trackLoop: } resp.Tracks = append(resp.Tracks, &enginelibrary.ListTrack{ Metadata: metadata, - PreviewArtwork: demoTrackArtwork, + PreviewArtwork: demoTrackPreviewArtwork, }) } return resp, nil diff --git a/cmd/storage/http.go b/cmd/storage/http.go index 4f3bb45..26dd9f8 100644 --- a/cmd/storage/http.go +++ b/cmd/storage/http.go @@ -3,28 +3,59 @@ package main import ( "bytes" "io" + "log" "net/http" + "net/url" "strconv" "github.com/gorilla/mux" ) -func newHTTPServiceHandler() http.Handler { - r := mux.NewRouter() - r.Get("/download/{path}").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - requestedPath := vars["path"] - if requestedPath != demoTrackURL { - w.WriteHeader(http.StatusNotFound) - return - } - w.Header().Set("Content-length", strconv.Itoa(len(demoTrackBytes))) - w.WriteHeader(http.StatusOK) - f := bytes.NewReader(demoTrackBytes) - io.Copy(w, f) - }) - r.Get("/ping").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) +func logMuxHandling(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Mux, handling: %s %s", r.Method, r.URL.String()) + h.ServeHTTP(w, r) }) +} + +func handleNotFound(w http.ResponseWriter, r *http.Request) { + log.Printf("Mux, not found: %s %s", r.Method, r.URL.String()) + w.WriteHeader(http.StatusNotFound) +} + +func handlePing(w http.ResponseWriter, r *http.Request) { + log.Println("HTTP: Ping") + w.WriteHeader(http.StatusOK) +} + +func handleDownload(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + requestedPath, err := url.PathUnescape(vars["path"]) + if err != nil { + log.Println("HTTP: Download, bad path:", requestedPath) + w.WriteHeader(http.StatusBadRequest) + return + } + if requestedPath != demoTrackURLGRPC { + log.Println("HTTP: Download, not found:", requestedPath) + w.WriteHeader(http.StatusNotFound) + return + } + log.Println("HTTP: Download, OK:", requestedPath) + w.Header().Set("Content-type", "application/octet-stream") + w.Header().Set("Content-length", strconv.Itoa(len(demoTrackBytes))) + w.WriteHeader(http.StatusOK) + f := bytes.NewReader(demoTrackBytes) + io.Copy(w, f) +} + +func eaasHTTPHandler() http.Handler { + r := mux.NewRouter() + r.Use(logMuxHandling) + r.NotFoundHandler = http.HandlerFunc(handleNotFound) + r.UseEncodedPath() + r.SkipClean(true) + r.HandleFunc("/download/{path}", handleDownload).Methods(http.MethodGet) + r.HandleFunc("/ping", handlePing).Methods(http.MethodGet) return r } diff --git a/cmd/storage/main.go b/cmd/storage/main.go index a12738e..e573b7a 100644 --- a/cmd/storage/main.go +++ b/cmd/storage/main.go @@ -1,14 +1,9 @@ package main import ( - "bytes" "context" "crypto/rand" - "encoding/binary" - "encoding/hex" - "errors" "fmt" - "io" "log" "net" "net/http" @@ -17,6 +12,7 @@ import ( "syscall" "time" + "github.com/icedream/go-stagelinq/eaas" "github.com/icedream/go-stagelinq/eaas/proto/enginelibrary" "github.com/icedream/go-stagelinq/eaas/proto/networktrust" "google.golang.org/grpc" @@ -28,26 +24,28 @@ const ( timeout = 5 * time.Second ) +var hostname string + +func init() { + var err error + hostname, err = os.Hostname() + if err != nil { + hostname = "eaas-demoserver" + } +} + func main() { + // Generate random token to identify with. + // + // Engine uses the token to know whether you just logged onto the network or + // whether you're a library that just restarted. For our demo purposes this + // doesn't matter too much though, so we just regenerate on bootup. var token [16]byte if _, err := rand.Read(token[:]); err != nil { panic(err) } - // listener, err := stagelinq.ListenWithConfiguration(&stagelinq.ListenerConfiguration{ - // DiscoveryTimeout: timeout, - // SoftwareName: appName, - // SoftwareVersion: appVersion, - // Name: "testing", - // }) - // if err != nil { - // panic(err) - // } - // defer listener.Close() - - // listener.AnnounceEvery(time.Second) - - ctx := context.TODO() + ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -67,87 +65,46 @@ func main() { s.Shutdown(timeout) }() - // set up grpc listener + // Set up gRPC API + grpcPort := eaas.DefaultEAASGRPCPort grpcListener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: net.IPv4(0, 0, 0, 0), - Port: 50010, + Port: int(grpcPort), }) if err != nil { panic(err) } enginelibrary.RegisterEngineLibraryServiceServer(grpcServer, &EngineLibraryServiceServer{}) - networktrust.RegisterNetworkTrustServiceServer(grpcServer, &NetworkTrustServer{}) + networktrust.RegisterNetworkTrustServiceServer(grpcServer, &NetworkTrustServiceServer{}) go func() { log.Println("Listening on GRPC") _ = grpcServer.Serve(grpcListener) }() - // set up http listener - s.Addr = ":50020" - s.Handler = newHTTPServiceHandler() + // Set up HTTP server + s.Addr = fmt.Sprintf(":%d", eaas.DefaultEAASHTTPPort) + s.Handler = eaasHTTPHandler() go func() { log.Println("Listening on HTTP") _ = s.ListenAndServe() }() - // listen for broadcasts - udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: net.IPv4(255, 255, 255, 255), - Port: 11224, + // Listen for beacon UDP broadcasts + log.Println("Beacon starting") + beacon, err := eaas.StartBeaconWithConfiguration(&eaas.BeaconConfiguration{ + Name: hostname, + SoftwareVersion: appVersion, + GRPCPort: grpcPort, + Token: demoToken, }) if err != nil { panic(err) } - udpC := make(chan *net.UDPAddr, 2) - go func() { - b := make([]byte, 6) - for { - n, addr, err := udpListener.ReadFromUDP(b) - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { - return - } - if err != nil { - log.Println("UDP error, ignoring:", err) - continue - } - if n != 6 { - log.Println("UDP message too short, ignoring") - continue - } - if !bytes.Equal(b, eaasMagic) { - log.Println("UDP broadcast invalid, ignoring") - continue - } - udpC <- addr - } - }() - hostname, err := os.Hostname() - if err != nil { - hostname = "demo" - } - go func() { - log.Println("Listening on UDP") - for { - select { - case addr := <-udpC: - msg := new(bytes.Buffer) - msg.Write(eaasResponseMagic) - msg.Write(token[:]) - messages.WriteNetworkString(msg, hostname) - uri := fmt.Sprintf("grpc://%s:%d", "192.168.188.120", 50010) - binary.Write(msg, binary.BigEndian, uint32(len(uri))) - messages.WriteNetworkString(msg, appVersion) - msg.Write([]byte{0, 0, 0, 2, 0, 0x5f}) // TODO - b := msg.Bytes() - log.Println("Sending UDP beacon\n", hex.Dump(b)) - udpListener.WriteToUDP(b, addr) - case <-ctx.Done(): - _ = udpListener.Close() - return - } - } + defer func() { + log.Println("Beacon shutting down") + beacon.Shutdown() }() - // wait for interrupt/term + // Wait for interrupt/term + log.Println("Running") <-ctx.Done() } diff --git a/cmd/storage/network_trust_service_server.go b/cmd/storage/network_trust_service_server.go index 60418c7..7e66274 100644 --- a/cmd/storage/network_trust_service_server.go +++ b/cmd/storage/network_trust_service_server.go @@ -2,21 +2,22 @@ package main import ( "context" - "time" + "log" "github.com/icedream/go-stagelinq/eaas/proto/networktrust" ) -var _ networktrust.NetworkTrustServiceServer = &NetworkTrustServer{} +var _ networktrust.NetworkTrustServiceServer = &NetworkTrustServiceServer{} -type NetworkTrustServer struct { +// NetworkTrustServiceServer is an example network trust service server +// implementation. All it does is approve everything that asks. +type NetworkTrustServiceServer struct { networktrust.UnimplementedNetworkTrustServiceServer } // CreateTrust implements networktrust.NetworkTrustServiceServer. -func (n *NetworkTrustServer) CreateTrust(ctx context.Context, _ *networktrust.CreateTrustRequest) (*networktrust.CreateTrustResponse, error) { - // safety sleep to not confuse Engine - time.After(time.Second) +func (n *NetworkTrustServiceServer) CreateTrust(ctx context.Context, req *networktrust.CreateTrustRequest) (*networktrust.CreateTrustResponse, error) { + log.Printf("CreateTrust: %+v", req) // Just allow all for now return &networktrust.CreateTrustResponse{ diff --git a/eaas/announcer.go b/eaas/announcer.go deleted file mode 100644 index 1890260..0000000 --- a/eaas/announcer.go +++ /dev/null @@ -1,158 +0,0 @@ -package eaas - -import ( - "bytes" - "context" - "errors" - "math/rand" - "net" - "sync" - "time" - - "github.com/icedream/go-stagelinq/internal/socket" -) - -// Announcer listens on UDP port 11224 for EAAS clients and announces itself to them. -type Announcer struct { - softwareVersion string - hostname string - packetConn net.PacketConn - token Token - port uint16 - shutdownCond *sync.Cond - shutdownWaitGroup sync.WaitGroup -} - -// Token returns our token that is being announced to the EAAS network. Use this -// token for further communication with services on other devices. -func (l *Announcer) Token() Token { - return l.token -} - -// Close shuts down the listener. -func (l *Announcer) Close() error { - // notify goroutines we are going to shut down and wait for them to finish - l.shutdownCond.Broadcast() - l.shutdownWaitGroup.Wait() - - return l.packetConn.Close() -} - -// Announce announces this EAAS listener to the network. -// -// This function should be called before actually listening in for devices to -// allow them to pick up our token for communication immediately. -func (l *Announcer) Announce() error { - return l.announce() -} - -// AnnounceEvery will start a goroutine which calls the Announce function at given interval. -// It will automatically terminate once this listener is shut down. -// A recommended value for the interval is 1 second. -func (l *Announcer) AnnounceEvery(interval time.Duration) { - shutdownC := make(chan interface{}, 1) - - // make Close() wait for us - l.shutdownWaitGroup.Add(1) - - // listen for shutdown signal broadcast, forward it to our own channel - go func() { - l.shutdownCond.L.Lock() - defer l.shutdownCond.L.Unlock() - l.shutdownCond.Wait() - shutdownC <- nil - }() - - go func() { - defer l.shutdownWaitGroup.Done() - - // timestamp for when to send next announcement - ticker := time.NewTicker(interval) - - // do first announcement immediately - l.Announce() - - for { - select { - case <-ticker.C: // next interval - announcement - if err := l.Announce(); errors.Is(err, net.ErrClosed) { - return - } - // NOTE - Considering AnnounceEvery is a fire-and-forget command we're ignoring other errors here for now. Not sure how to properly handle them otherwise atm. - case <-shutdownC: - return - } - } - }() -} - -func (l *Announcer) announce() (err error) { - // TODO - optimization: cache the built message because it will be sent repeatedly? - m := &eaasDiscoveryRequestMessage{} - b := new(bytes.Buffer) - err = m.WriteMessageTo(b) - if err != nil { - return - } - finalBytes := b.Bytes() - ips, err := socket.GetAllBroadcastIPs() - if err != nil { - return - } - for _, ip := range ips { - addr := makeEAASDiscoveryBroadcastAddress(ip) - packetConn, err := net.DialUDP("udp", nil, addr) - if err == nil { - _, _ = packetConn.Write(finalBytes) - packetConn.Close() - } - } - - return -} - -// Announce sets up a EAAS announcer. -func Announce() (announcer *Announcer, err error) { - return AnnounceWithConfiguration(nil) -} - -var zeroToken = Token{} - -// AnnounceWithConfiguration sets up a EAAS announcer with the given configuration. -func AnnounceWithConfiguration(announcerConfig *AnnouncerConfiguration) (announcer *Announcer, err error) { - // Use empty configuration if no configuration object was passed - if announcerConfig == nil { - announcerConfig = new(AnnouncerConfiguration) - } - - // Initialize token if none was configured - token := announcerConfig.Token - if bytes.Equal(announcerConfig.Token[:], zeroToken[:]) { - if _, err = rand.Read(token[:]); err != nil { - return - } - } - - // Use background context if none was configured - ctx := announcerConfig.Context - if ctx == nil { - ctx = context.Background() - } - - // We are setting up a shared UDP address socket here to allow other applications to still listen for EAAS discovery messages - config := &net.ListenConfig{ - Control: socket.SetSocketControlForReusePort, - } - packetConn, err := config.ListenPacket(ctx, eaasDiscoveryNetwork, eaasDiscoveryAddressString) - if err != nil { - return - } - - return &Announcer{ - hostname: announcerConfig.Name, - packetConn: packetConn, - softwareVersion: announcerConfig.SoftwareVersion, - token: token, - shutdownCond: sync.NewCond(&sync.Mutex{}), - }, nil -} diff --git a/eaas/announcer_configuration.go b/eaas/announcer_configuration.go deleted file mode 100644 index 29d6d9a..0000000 --- a/eaas/announcer_configuration.go +++ /dev/null @@ -1,22 +0,0 @@ -package eaas - -import ( - "context" -) - -// AnnouncerConfiguration contains configurable values for setting up a EAAS -// announcer. -type AnnouncerConfiguration struct { - // Context can be set to allow cancellation of network operations from somewhere else in the code. - Context context.Context - - // Name is the name under which we announce ourselves to the network. Denon - // software tends to use hostname here. - Name string - - // SoftwareVersion is your application's version. It is used for StagelinQ announcements to the network. - SoftwareVersion string - - // Token is used as part of announcements and main data communication. It is currently recommended to leave this empty. - Token Token -} diff --git a/eaas/beacon.go b/eaas/beacon.go new file mode 100644 index 0000000..0871c6d --- /dev/null +++ b/eaas/beacon.go @@ -0,0 +1,200 @@ +package eaas + +import ( + "bytes" + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "sync" + + "github.com/icedream/go-stagelinq/internal/messages" + "github.com/icedream/go-stagelinq/internal/socket" + "golang.org/x/net/ipv4" +) + +// Beacon listens on UDP port 11224 for EAAS clients and announces itself to them. +type Beacon struct { + softwareVersion string + hostname string + packetConn4 *ipv4.PacketConn + token Token + grpcHost string + grpcPort uint16 + shutdownWaitGroup sync.WaitGroup +} + +// Token returns our token that is being announced to the EAAS network. Use this +// token for further communication with services on other devices. +func (l *Beacon) Token() Token { + return l.token +} + +// Shutdown shuts down the listener. +func (l *Beacon) Shutdown() error { + err := l.packetConn4.Close() + + // wait for Listen goroutine to finish + l.shutdownWaitGroup.Wait() + + return err +} + +// List will start a goroutine which waits for EAAS clients to announce back to +// them. It will automatically terminate once this listener is shut down. +func (l *Beacon) listen() { + // make Close() wait for us + l.shutdownWaitGroup.Add(1) + + go func() { + defer l.shutdownWaitGroup.Done() + + b := make([]byte, 8) + for { + n, cm, addr, err := l.packetConn4.ReadFrom(b) + if errors.Is(err, net.ErrClosed) { + break + } + if err != nil { + // TODO - log this somehow + continue + } + if err = l.handleIncomingIPv4Packet(b[0:n], cm, addr); err != nil { + // TODO - log this somehow + continue + } + + } + }() +} + +func (l *Beacon) getIP() net.IP { + return socket.GetIPFromAddress(l.packetConn4.LocalAddr()) +} + +func (l *Beacon) getGRPCURL(ip net.IP) string { + var host string + switch { + case len(l.grpcHost) > 0: + host = l.grpcHost + default: + host = ip.String() + } + return fmt.Sprintf("grpc://%s:%d", host, l.grpcPort) +} + +func (l *Beacon) replyIPv4(cm *ipv4.ControlMessage, srcAddr net.Addr) error { + if cm == nil { + panic("control message must not be nil") + } + // figure out the an IP that we could be reachable from based on the + // interface the broadcast came from + intf, err := net.InterfaceByIndex(cm.IfIndex) + if err != nil { + return err + } + addrs, err := intf.Addrs() + if err != nil { + return err + } + ip := socket.GetIPFromAddress(addrs[0]) + // TODO - optimization: cache the built message because it will be sent repeatedly? + m := &eaasDiscoveryResponseMessage{ + TokenPrefixedMessage: messages.TokenPrefixedMessage{ + Token: messages.Token(l.token), + }, + Hostname: l.hostname, + SoftwareVersion: l.softwareVersion, + URL: l.getGRPCURL(ip), + Extra: "_", + } + b := new(bytes.Buffer) + if err := m.WriteMessageTo(b); err != nil { + return err + } + ncm := &ipv4.ControlMessage{ + IfIndex: cm.IfIndex, + } + if _, err := l.packetConn4.WriteTo(b.Bytes(), ncm, srcAddr); err != nil { + return err + } + return nil +} + +func (l *Beacon) handleIncomingIPv4Packet(b []byte, cm *ipv4.ControlMessage, srcAddr net.Addr) error { + // decode message + r := bytes.NewReader(b) + m := new(eaasDiscoveryRequestMessage) + if err := m.ReadMessageFrom(r); err != nil { + return err + } + + return l.replyIPv4(cm, srcAddr) +} + +// StartBeacon sets up an EAAS beacon. +func StartBeacon() (*Beacon, error) { + return StartBeaconWithConfiguration(nil) +} + +var zeroToken = Token{} + +// StartBeaconWithConfiguration sets up a EAAS announcer with the given configuration. +func StartBeaconWithConfiguration(beaconConfig *BeaconConfiguration) (beacon *Beacon, err error) { + // Use empty configuration if no configuration object was passed + if beaconConfig == nil { + beaconConfig = new(BeaconConfiguration) + } + + // Initialize token if none was configured + token := beaconConfig.Token + if bytes.Equal(beaconConfig.Token[:], zeroToken[:]) { + if _, err = rand.Read(token[:]); err != nil { + return + } + } + + // Use default EAAS gRPC port if none was set + grpcPort := beaconConfig.GRPCPort + if grpcPort == 0 { + grpcPort = DefaultEAASGRPCPort + } + + config := &net.ListenConfig{ + Control: socket.SetSocketControlForReusePort, + } + packetConn, err := config.ListenPacket( + context.TODO(), + "udp4", + makeEAASDiscoveryAddress(net.IPv4zero).String()) + if err != nil { + return + } + + ipv4PacketConn := ipv4.NewPacketConn(packetConn) + // NOTE - this part only works on Linux and is unimplemented on Windows... + // + // This is however necessary so we return the correct IP for Engine software + // to connect to. + if err := ipv4PacketConn.SetControlMessage(ipv4.FlagInterface, true); err != nil { + err = errors.Join( + fmt.Errorf( + "failed to set control messages to forward interface index: %w", + err), + packetConn.Close()) + return nil, err + } + + b := &Beacon{ + packetConn4: ipv4PacketConn, + hostname: beaconConfig.Name, + softwareVersion: beaconConfig.SoftwareVersion, + token: token, + grpcHost: beaconConfig.GRPCHost, + grpcPort: grpcPort, + } + go b.listen() + + return b, nil +} diff --git a/eaas/beacon_configuration.go b/eaas/beacon_configuration.go new file mode 100644 index 0000000..83976c6 --- /dev/null +++ b/eaas/beacon_configuration.go @@ -0,0 +1,35 @@ +package eaas + +// The default EAAS gRPC API port. +const DefaultEAASGRPCPort uint16 = 50010 + +// The default EAAS HTTP server port. +const DefaultEAASHTTPPort uint16 = DefaultEAASGRPCPort + 10 + +// BeaconConfiguration contains configurable values for setting up a EAAS +// announcer. +type BeaconConfiguration struct { + // Name is the name under which we announce ourselves to the network. Denon + // software tends to use hostname here. + Name string + + // SoftwareVersion is your application's version. It is used for StagelinQ + // announcements to the network. + SoftwareVersion string + + // Token is used as part of announcements and main data communication. It is + // currently recommended to leave this empty. + Token Token + + // The host to report back to clients with where the EAAS gRPC API is + // listening on. + // + // If left empty, defaults to the IP the beacon is bound to. It is + // recommended to leave this empty. + GRPCHost string + + // The port to report back to clients which the EAAS gRPC API is listening on. + // + // If left zero, defaults to the default EAAS gRPC API port (50010). + GRPCPort uint16 +} diff --git a/eaas/listener.go b/eaas/listener.go index 544b5ad..46686ba 100644 --- a/eaas/listener.go +++ b/eaas/listener.go @@ -32,7 +32,7 @@ const ( eaasDiscoveryAddressString = ":11224" ) -func makeEAASDiscoveryBroadcastAddress(ip net.IP) *net.UDPAddr { +func makeEAASDiscoveryAddress(ip net.IP) *net.UDPAddr { return &net.UDPAddr{ IP: ip, Port: 11224, diff --git a/go.mod b/go.mod index 747b021..baa3f61 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/lithammer/fuzzysearch v1.1.8 github.com/rivo/tview v0.0.0-20240307173318-e804876934a1 + golang.org/x/net v0.22.0 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 ) @@ -90,7 +91,6 @@ require ( golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.16.0 // indirect - golang.org/x/net v0.22.0 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect diff --git a/internal/socket/addr_ip.go b/internal/socket/addr_ip.go new file mode 100644 index 0000000..e07a0b4 --- /dev/null +++ b/internal/socket/addr_ip.go @@ -0,0 +1,18 @@ +package socket + +import ( + "net" +) + +func GetIPFromAddress(address net.Addr) net.IP { + switch convertedAddress := address.(type) { + case *net.UDPAddr: + return convertedAddress.IP + case *net.TCPAddr: + return convertedAddress.IP + case *net.IPNet: + return convertedAddress.IP + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_ip_test.go b/internal/socket/addr_ip_test.go new file mode 100644 index 0000000..246a067 --- /dev/null +++ b/internal/socket/addr_ip_test.go @@ -0,0 +1,24 @@ +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GetIPFromAddress(t *testing.T) { + testIP := net.IPv4(1, 2, 3, 4) + require.Equal(t, testIP, GetIPFromAddress(&net.TCPAddr{ + IP: testIP, + Port: 12345, + })) + require.Equal(t, testIP, GetIPFromAddress(&net.UDPAddr{ + IP: testIP, + Port: 12345, + })) + require.Equal(t, testIP, GetIPFromAddress(&net.IPNet{ + IP: testIP, + Mask: net.IPv4Mask(255, 255, 255, 0), + })) +} diff --git a/internal/socket/addr_mask.go b/internal/socket/addr_mask.go new file mode 100644 index 0000000..20d4309 --- /dev/null +++ b/internal/socket/addr_mask.go @@ -0,0 +1,14 @@ +package socket + +import ( + "net" +) + +func GetMaskFromAddress(address net.Addr) net.IPMask { + switch convertedAddress := address.(type) { + case *net.IPNet: + return convertedAddress.Mask + default: + return nil + } +} diff --git a/internal/socket/addr_mask_test.go b/internal/socket/addr_mask_test.go new file mode 100644 index 0000000..cc908b6 --- /dev/null +++ b/internal/socket/addr_mask_test.go @@ -0,0 +1,16 @@ +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GetMaskFromAddress(t *testing.T) { + testMask := net.IPv4Mask(255, 255, 255, 0) + require.Equal(t, testMask, GetMaskFromAddress(&net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: testMask, + })) +} diff --git a/internal/socket/addr_port_go1.18.go b/internal/socket/addr_port_go1.18.go new file mode 100644 index 0000000..e1cc446 --- /dev/null +++ b/internal/socket/addr_port_go1.18.go @@ -0,0 +1,22 @@ +//go:build go1.18 +// +build go1.18 + +package socket + +import ( + "net" + "net/netip" +) + +type convertableToAddrPort interface { + AddrPort() netip.AddrPort +} + +func GetPortFromAddress(address net.Addr) uint16 { + switch convertedAddress := address.(type) { + case convertableToAddrPort: + return convertedAddress.AddrPort().Port() + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_port_legacy.go b/internal/socket/addr_port_legacy.go new file mode 100644 index 0000000..4d94da9 --- /dev/null +++ b/internal/socket/addr_port_legacy.go @@ -0,0 +1,24 @@ +//go:build !go1.18 +// +build !go1.18 + +package socket + +import ( + "net" + "net/netip" +) + +type convertableToAddrPort interface { + AddrPort() netip.AddrPort +} + +func GetPortFromAddress(address net.Addr) uint16 { + switch convertedAddress := address.(type) { + case *net.UDPAddr: + return uint16(convertedAddress.Port) + case *net.TCPAddr: + return uint16(convertedAddress.Port) + default: + panic("unsupported network address type") + } +} diff --git a/internal/socket/addr_port_test.go b/internal/socket/addr_port_test.go new file mode 100644 index 0000000..f3de6e0 --- /dev/null +++ b/internal/socket/addr_port_test.go @@ -0,0 +1,19 @@ +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_GetPortFromAddress(t *testing.T) { + require.Equal(t, 12345, GetPortFromAddress(&net.TCPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 12345, + })) + require.Equal(t, 12345, GetPortFromAddress(&net.UDPAddr{ + IP: net.IPv4(1, 2, 3, 4), + Port: 12345, + })) +} diff --git a/internal/socket/broadcast.go b/internal/socket/broadcast.go index 6f81ad3..371f751 100644 --- a/internal/socket/broadcast.go +++ b/internal/socket/broadcast.go @@ -39,3 +39,29 @@ addrsLoop: retval = ips return } + +func MakeBroadcastIP(ip net.IP, mask net.IPMask) (bip net.IP) { + // get 4-byte representation of ipv4 is possible, nil if not an ipv4 address + convertedIPv4 := false + if ip4 := ip.To4(); ip4 != nil { + convertedIPv4 = len(ip) != len(ip4) + ip = ip4 + } + + if len(mask) != len(ip) { + // mask and ip are different sizes, panic! + panic("net mask and ip address are different sizes") + } + + bip = make(net.IP, len(ip)) + for i := range mask { + bip[i] = ip[i] | ^mask[i] + } + + // convert back to 16-byte representation if input was 16-byte, too + if convertedIPv4 { + bip = bip.To16() + } + + return +} diff --git a/internal/socket/util_test.go b/internal/socket/broadcast_test.go similarity index 78% rename from internal/socket/util_test.go rename to internal/socket/broadcast_test.go index 457597e..f294e9d 100644 --- a/internal/socket/util_test.go +++ b/internal/socket/broadcast_test.go @@ -7,17 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func Test_getPort(t *testing.T) { - require.Equal(t, 12345, GetPortFromAddress(&net.TCPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 12345, - })) - require.Equal(t, 12345, GetPortFromAddress(&net.UDPAddr{ - IP: net.IPv4(1, 2, 3, 4), - Port: 12345, - })) -} - func Test_makeBroadcastIP(t *testing.T) { testValues := []struct { IP net.IP diff --git a/internal/socket/reuse_port_nonwindows.go b/internal/socket/reuse_port_nonwindows.go deleted file mode 100644 index 88f9043..0000000 --- a/internal/socket/reuse_port_nonwindows.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !windows -// +build !windows - -package socket - -import "syscall" - -func SetSocketControlForReusePort(_, _ string, c syscall.RawConn) error { - return c.Control(func(fd uintptr) { - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_DONTROUTE, 1) - }) -} diff --git a/internal/socket/sockopts.go b/internal/socket/sockopts.go new file mode 100644 index 0000000..694a91a --- /dev/null +++ b/internal/socket/sockopts.go @@ -0,0 +1,9 @@ +package socket + +import "syscall" + +// The function type needed to be provided to [net.Dialer.Control]. +type controlFunc func(network, address string, c syscall.RawConn) error + +// Sanity check +var _ controlFunc = SetSocketControlForReusePort diff --git a/internal/socket/sockopts_nonwindows.go b/internal/socket/sockopts_nonwindows.go new file mode 100644 index 0000000..7ec15c1 --- /dev/null +++ b/internal/socket/sockopts_nonwindows.go @@ -0,0 +1,23 @@ +//go:build !windows +// +build !windows + +package socket + +import ( + "log" + "syscall" +) + +func SetSocketControlForReusePort(_, _ string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + log.Println("Could not set sockopt SO_REUSEADDR:", err) + } + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1); err != nil { + log.Println("Could not set sockopt SO_BROADCAST:", err) + } + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_DONTROUTE, 1); err != nil { + log.Println("Could not set sockopt SO_DONTROUTE:", err) + } + }) +} diff --git a/internal/socket/reuse_port_windows.go b/internal/socket/sockopts_windows.go similarity index 100% rename from internal/socket/reuse_port_windows.go rename to internal/socket/sockopts_windows.go diff --git a/internal/socket/util.go b/internal/socket/util.go deleted file mode 100644 index 7430981..0000000 --- a/internal/socket/util.go +++ /dev/null @@ -1,42 +0,0 @@ -package socket - -import ( - "net" -) - -func GetPortFromAddress(address net.Addr) int { - switch convertedAddress := address.(type) { - case *net.UDPAddr: - return convertedAddress.Port - case *net.TCPAddr: - return convertedAddress.Port - default: - panic("unsupported network address type") - } -} - -func MakeBroadcastIP(ip net.IP, mask net.IPMask) (bip net.IP) { - // get 4-byte representation of ipv4 is possible, nil if not an ipv4 address - convertedIPv4 := false - if ip4 := ip.To4(); ip4 != nil { - convertedIPv4 = len(ip) != len(ip4) - ip = ip4 - } - - if len(mask) != len(ip) { - // mask and ip are different sizes, panic! - panic("net mask and ip address are different sizes") - } - - bip = make(net.IP, len(ip)) - for i := range mask { - bip[i] = ip[i] | ^mask[i] - } - - // convert back to 16-byte representation if input was 16-byte, too - if convertedIPv4 { - bip = bip.To16() - } - - return -}