diff --git a/README.md b/README.md index fbbc18f..7b37008 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # go-lanscan -![Coverage](https://img.shields.io/badge/Coverage-91.2%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-91.8%25-brightgreen) A network cli and golang package that allows you to perform arp and syn scanning on a local area network. @@ -84,12 +84,10 @@ First you must install the following dependencies You can provide the following options to all scanners -- Provide callback for notifications when packet requests are sent to target +- Provide channel for notifications when packet requests are sent to target ```go - callback := func(request *scanner.Request) { - fmt.Printf("syn packet sent to %s on port %s", request.IP, request.Port) - } + requests := make(chan *scanner.Request) synScanner := scanner.NewSynScanner( targets, @@ -98,15 +96,15 @@ You can provide the following options to all scanners listenPort, synResults, synDone, - scanner.WithRequestNotifications(callback), + scanner.WithRequestNotifications(requests), ) // or - option := scanner.WithRequestNotifications(callback) - option(synScanner) + option := scanner.WithRequestNotifications(requests) + option(requests) // or - synScanner.SetRequestNotifications(callback) + synScanner.SetRequestNotifications(requests) ``` - Provide your own idle timeout. If no packets are received from our targets diff --git a/internal/core/core.go b/internal/core/core.go index d52adbe..c15078b 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -59,25 +59,27 @@ func (r *Results) MarshalJSON() ([]byte, error) { } type Core struct { - arpOnly bool - printJson bool - noProgress bool - outFile string - portLen int - results *Results - pw progress.Writer - arpTracker *progress.Tracker - synTracker *progress.Tracker - errorChan chan error - scanner scanner.Scanner - mux *sync.RWMutex - log logger.Logger + arpOnly bool + printJson bool + noProgress bool + outFile string + portLen int + results *Results + pw progress.Writer + arpTracker *progress.Tracker + synTracker *progress.Tracker + requestNotifier chan *scanner.Request + errorChan chan error + scanner scanner.Scanner + mux *sync.RWMutex + log logger.Logger } func New() *Core { return &Core{ - mux: &sync.RWMutex{}, - log: logger.New(), + requestNotifier: make(chan *scanner.Request), + mux: &sync.RWMutex{}, + log: logger.New(), } } @@ -102,7 +104,7 @@ func (c *Core) Initialize( if noProgress { logger.SetGlobalLevel(zerolog.Disabled) } else { - coreScanner.SetRequestNotifications(c.requestCallback) + coreScanner.SetRequestNotifications(c.requestNotifier) } c.scanner = coreScanner @@ -122,11 +124,11 @@ func (c *Core) Run() error { start := time.Now() if !c.noProgress { + c.pw.AppendTracker(c.arpTracker) + go c.monitorRequestNotifications() go c.pw.Render() } - c.pw.AppendTracker(c.arpTracker) - // run in go routine so we can process in results in parallel go func() { if err := c.scanner.Scan(); err != nil { @@ -157,8 +159,6 @@ OUTER: } } - c.scanner.Stop() - c.log.Info().Str("duration", time.Since(start).String()).Msg("go-lanscan complete") return nil @@ -252,44 +252,6 @@ func (c *Core) processArpDone() { } } -func (c *Core) requestCallback(r *scanner.Request) { - switch r.Type { - case scanner.ArpRequest: - c.arpTracker.Increment(1) - - message := fmt.Sprintf("arp - scanning %s", r.IP) - - if c.arpTracker.IsDone() { - message = "arp - scan complete" - // delay to print line after message is updated - time.AfterFunc(time.Millisecond*100, func() { - c.log.Info().Msg("compiling arp results...") - }) - } - - c.arpTracker.Message = message - case scanner.SynRequest: - c.synTracker.Increment(1) - - message := fmt.Sprintf( - "syn - scanning port %d on %s", - r.Port, - r.IP, - ) - - if c.synTracker.IsDone() { - message = "syn - scan complete" - // delay to print line after message is updated - time.AfterFunc(time.Millisecond*100, func() { - c.log.Info().Msg("compiling syn results...") - }) - } - - c.synTracker.Message = message - - } -} - func (c *Core) printArpResults() { c.mux.RLock() defer c.mux.RUnlock() @@ -387,6 +349,45 @@ func (c *Core) printSynResults() { } } +func (c *Core) monitorRequestNotifications() { + for r := range c.requestNotifier { + switch r.Type { + case scanner.ArpRequest: + c.arpTracker.Increment(1) + + message := fmt.Sprintf("arp - scanning %s", r.IP) + + if c.arpTracker.IsDone() { + message = "arp - scan complete" + // delay to print line after message is updated + time.AfterFunc(time.Millisecond*100, func() { + c.log.Info().Msg("compiling arp results...") + }) + } + + c.arpTracker.Message = message + case scanner.SynRequest: + c.synTracker.Increment(1) + + message := fmt.Sprintf( + "syn - scanning port %d on %s", + r.Port, + r.IP, + ) + + if c.synTracker.IsDone() { + message = "syn - scan complete" + // delay to print line after message is updated + time.AfterFunc(time.Millisecond*100, func() { + c.log.Info().Msg("compiling syn results...") + }) + } + + c.synTracker.Message = message + } + } +} + // helpers func progressWriter() progress.Writer { pw := progress.NewWriter() diff --git a/internal/core/core_test.go b/internal/core/core_test.go index 4334cd5..abc9af4 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -190,23 +190,23 @@ func TestCore(t *testing.T) { mockScanner.EXPECT().Results().Return(scanResults).AnyTimes() - var callback func(r *scanner.Request) + var requestNotifier chan *scanner.Request mockScanner.EXPECT().SetRequestNotifications(gomock.Any()).DoAndReturn( - func(cb func(r *scanner.Request)) { - callback = cb + func(c chan *scanner.Request) { + requestNotifier = c }, ) mockScanner.EXPECT().Scan().DoAndReturn(func() error { mac, _ := net.ParseMAC("00:00:00:00:00:00") - callback(&scanner.Request{ - Type: scanner.ArpRequest, - IP: "172.17.0.1", - }) - time.AfterFunc(time.Millisecond*100, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.ArpRequest, + IP: "172.17.0.1", + } + scanResults <- &scanner.ScanResult{ Type: scanner.ARPResult, Payload: &scanner.ArpScanResult{ @@ -225,8 +225,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -259,23 +257,23 @@ func TestCore(t *testing.T) { mockScanner.EXPECT().Results().Return(scanResults).AnyTimes() - var callback func(r *scanner.Request) + var requestNotifier chan *scanner.Request mockScanner.EXPECT().SetRequestNotifications(gomock.Any()).DoAndReturn( - func(cb func(r *scanner.Request)) { - callback = cb + func(c chan *scanner.Request) { + requestNotifier = c }, ) mockScanner.EXPECT().Scan().DoAndReturn(func() error { mac, _ := net.ParseMAC("00:00:00:00:00:00") - callback(&scanner.Request{ - Type: scanner.ArpRequest, - IP: "172.17.0.1", - }) - time.AfterFunc(time.Millisecond*100, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.ArpRequest, + IP: "172.17.0.1", + } + scanResults <- &scanner.ScanResult{ Type: scanner.ARPResult, Payload: &scanner.ArpScanResult{ @@ -294,8 +292,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -346,8 +342,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -379,11 +373,11 @@ func TestCore(t *testing.T) { mockScanner.EXPECT().Results().Return(scanResults).AnyTimes() - var callback func(r *scanner.Request) + var requestNotifier chan *scanner.Request mockScanner.EXPECT().SetRequestNotifications(gomock.Any()).DoAndReturn( - func(cb func(r *scanner.Request)) { - callback = cb + func(c chan *scanner.Request) { + requestNotifier = c }, ) @@ -391,18 +385,12 @@ func TestCore(t *testing.T) { ip := net.ParseIP("172.17.0.1") mac, _ := net.ParseMAC("00:00:00:00:00:00") - callback(&scanner.Request{ - Type: scanner.ArpRequest, - IP: "172.17.0.1", - }) - - callback(&scanner.Request{ - Type: scanner.SynRequest, - IP: "172.17.0.1", - Port: 22, - }) - time.AfterFunc(time.Millisecond*100, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.ArpRequest, + IP: "172.17.0.1", + } + scanResults <- &scanner.ScanResult{ Type: scanner.ARPResult, Payload: &scanner.ArpScanResult{ @@ -418,6 +406,12 @@ func TestCore(t *testing.T) { } }) time.AfterFunc(time.Millisecond*300, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.SynRequest, + IP: "172.17.0.1", + Port: 22, + } + scanResults <- &scanner.ScanResult{ Type: scanner.SYNResult, Payload: &scanner.SynScanResult{ @@ -441,8 +435,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -475,11 +467,11 @@ func TestCore(t *testing.T) { mockScanner.EXPECT().Results().Return(scanResults).AnyTimes() - var callback func(r *scanner.Request) + var requestNotifier chan *scanner.Request mockScanner.EXPECT().SetRequestNotifications(gomock.Any()).DoAndReturn( - func(cb func(r *scanner.Request)) { - callback = cb + func(c chan *scanner.Request) { + requestNotifier = c }, ) @@ -487,18 +479,12 @@ func TestCore(t *testing.T) { ip := net.ParseIP("172.17.0.1") mac, _ := net.ParseMAC("00:00:00:00:00:00") - callback(&scanner.Request{ - Type: scanner.ArpRequest, - IP: "172.17.0.1", - }) - - callback(&scanner.Request{ - Type: scanner.SynRequest, - IP: "172.17.0.1", - Port: 22, - }) - time.AfterFunc(time.Millisecond*100, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.ArpRequest, + IP: "172.17.0.1", + } + scanResults <- &scanner.ScanResult{ Type: scanner.ARPResult, Payload: &scanner.ArpScanResult{ @@ -514,6 +500,12 @@ func TestCore(t *testing.T) { } }) time.AfterFunc(time.Millisecond*300, func() { + requestNotifier <- &scanner.Request{ + Type: scanner.SynRequest, + IP: "172.17.0.1", + Port: 22, + } + scanResults <- &scanner.ScanResult{ Type: scanner.SYNResult, Payload: &scanner.SynScanResult{ @@ -537,8 +529,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -611,8 +601,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -646,8 +634,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 10, diff --git a/mock/scanner/scanner.go b/mock/scanner/scanner.go index 8f48373..dee3c27 100644 --- a/mock/scanner/scanner.go +++ b/mock/scanner/scanner.go @@ -106,7 +106,7 @@ func (mr *MockScannerMockRecorder) SetPacketCapture(arg0 any) *gomock.Call { } // SetRequestNotifications mocks base method. -func (m *MockScanner) SetRequestNotifications(arg0 func(*scanner.Request)) { +func (m *MockScanner) SetRequestNotifications(arg0 chan *scanner.Request) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetRequestNotifications", arg0) } @@ -208,6 +208,22 @@ func (mr *MockPacketCaptureHandleMockRecorder) WritePacketData(arg0 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacketData", reflect.TypeOf((*MockPacketCaptureHandle)(nil).WritePacketData), arg0) } +// ZeroCopyReadPacketData mocks base method. +func (m *MockPacketCaptureHandle) ZeroCopyReadPacketData() ([]byte, gopacket.CaptureInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ZeroCopyReadPacketData") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(gopacket.CaptureInfo) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ZeroCopyReadPacketData indicates an expected call of ZeroCopyReadPacketData. +func (mr *MockPacketCaptureHandleMockRecorder) ZeroCopyReadPacketData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZeroCopyReadPacketData", reflect.TypeOf((*MockPacketCaptureHandle)(nil).ZeroCopyReadPacketData)) +} + // MockPacketCapture is a mock of PacketCapture interface. type MockPacketCapture struct { ctrl *gomock.Controller diff --git a/pkg/internal/test-helper/test-helper.go b/pkg/internal/test-helper/test-helper.go index a74d3e4..02a4dc5 100644 --- a/pkg/internal/test-helper/test-helper.go +++ b/pkg/internal/test-helper/test-helper.go @@ -28,6 +28,8 @@ func NewArpReplyReadResult(srcIP net.IP, srcHwAddr net.HardwareAddr) (data []byt DstProtAddress: []byte{192, 168, 1, 1}, // Target IP address } + var payload gopacket.Payload + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -36,7 +38,7 @@ func NewArpReplyReadResult(srcIP net.IP, srcHwAddr net.HardwareAddr) (data []byt } // Serialize the ARP packet. - gopacket.SerializeLayers(buf, opts, ð, &arp) + gopacket.SerializeLayers(buf, opts, ð, &arp, &payload) return buf.Bytes(), gopacket.CaptureInfo{}, nil } @@ -60,6 +62,8 @@ func NewArpRequestReadResult() (data []byte, ci gopacket.CaptureInfo, err error) DstProtAddress: []byte{192, 168, 1, 1}, // Target IP address } + var payload gopacket.Payload + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -68,7 +72,7 @@ func NewArpRequestReadResult() (data []byte, ci gopacket.CaptureInfo, err error) } // Serialize the ARP packet. - gopacket.SerializeLayers(buf, opts, ð, &arp) + gopacket.SerializeLayers(buf, opts, ð, &arp, &payload) return buf.Bytes(), gopacket.CaptureInfo{}, nil } diff --git a/pkg/scanner/arpscan.go b/pkg/scanner/arpscan.go index 4681dae..43fa2d8 100644 --- a/pkg/scanner/arpscan.go +++ b/pkg/scanner/arpscan.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/robgonnella/go-lanscan/internal/logger" "github.com/robgonnella/go-lanscan/internal/util" @@ -26,7 +27,7 @@ type ArpScanner struct { cap PacketCapture handle PacketCaptureHandle resultChan chan *ScanResult - notificationCB func(a *Request) + requestNotifier chan *Request scanning bool lastPacketSentAt time.Time idleTimeout time.Duration @@ -93,7 +94,7 @@ func (s *ArpScanner) Scan() error { s.networkInfo.Interface().Name, 65536, true, - s.idleTimeout, + pcap.BlockForever, ) if err != nil { @@ -141,8 +142,8 @@ func (s *ArpScanner) Stop() { } } -func (s *ArpScanner) SetRequestNotifications(cb func(a *Request)) { - s.notificationCB = cb +func (s *ArpScanner) SetRequestNotifications(c chan *Request) { + s.requestNotifier = c } func (s *ArpScanner) SetIdleTimeout(duration time.Duration) { @@ -161,22 +162,55 @@ func (s *ArpScanner) SetPacketCapture(cap PacketCapture) { } func (s *ArpScanner) readPackets() { - packetSource := gopacket.NewPacketSource(s.handle, layers.LayerTypeEthernet) - packetSource.DecodeOptions.NoCopy = true - packetSource.DecodeOptions.Lazy = true + stopChan := make(chan struct{}) - defer s.reset() + go func() { + for { + select { + case <-stopChan: + return + default: + var eth layers.Ethernet + var arp layers.ARP + var payload gopacket.Payload + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &arp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error reading packet") + continue + } + + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error decoding packet") + continue + } + + OUTER: + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeARP: + go s.handleARPLayer(&arp) + break OUTER + } + } + } + } + }() + + defer func() { + stopChan <- struct{}{} + s.reset() + }() for { select { case <-s.ctx.Done(): return - case packet := <-packetSource.Packets(): - arpLayer := packet.Layer(layers.LayerTypeARP) - - if arpLayer != nil { - go s.handleARPLayer(arpLayer.(*layers.ARP)) - } default: s.packetSentAtMux.RLock() packetSentAt := s.lastPacketSentAt @@ -255,8 +289,10 @@ func (s *ArpScanner) writePacketData(ip net.IP) error { return err } - if s.notificationCB != nil { - go s.notificationCB(&Request{Type: ArpRequest, IP: ip.String()}) + if s.requestNotifier != nil { + go func() { + s.requestNotifier <- &Request{Type: ArpRequest, IP: ip.String()} + }() } return nil diff --git a/pkg/scanner/arpscan_test.go b/pkg/scanner/arpscan_test.go index 1237a05..2e7db05 100644 --- a/pkg/scanner/arpscan_test.go +++ b/pkg/scanner/arpscan_test.go @@ -111,6 +111,62 @@ func TestArpScanner(t *testing.T) { assert.ErrorIs(st, mockErr, err) }) + t.Run("prints debug message if reading packets returns error", func(st *testing.T) { + cap := mock_scanner.NewMockPacketCapture(ctrl) + handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) + mockNetInfo := mock_network.NewMockNetwork(ctrl) + + arpScanner := scanner.NewArpScanner( + []string{}, + mockNetInfo, + scanner.WithPacketCapture(cap), + ) + + wg := sync.WaitGroup{} + wg.Add(1) + + mockNetInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) + mockNetInfo.EXPECT().IPNet().Return(mockIPNet).AnyTimes() + mockNetInfo.EXPECT().UserIP().Return(mockUserIP) + mockNetInfo.EXPECT().Cidr().AnyTimes().Return(cidr) + + cap.EXPECT().OpenLive( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any()).Return(handle, nil) + + cap.EXPECT().SerializeLayers(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + handle.EXPECT().Close().AnyTimes() + + handle.EXPECT().WritePacketData(gomock.Any()).DoAndReturn(func(data []byte) (err error) { + defer func() { + arpScanner.Stop() + wg.Done() + }() + return nil + }) + + firstCall := true + handle.EXPECT().ReadPacketData().AnyTimes().DoAndReturn(func() (data []byte, ci gopacket.CaptureInfo, err error) { + if firstCall { + firstCall = false + return nil, gopacket.CaptureInfo{}, errors.New("mock ReadPacketData error") + } + return test_helper.NewArpReplyReadResult( + mockNonIncludedArpSrcIP, + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + ) + }) + + err := arpScanner.Scan() + + wg.Wait() + + assert.NoError(st, err) + }) + t.Run("performs arp scan on default network info", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) @@ -498,20 +554,23 @@ func TestArpScanner(t *testing.T) { assert.Panics(st, panicTestFunc) }) - t.Run("calls request notification callback", func(st *testing.T) { + t.Run("sends request notifications", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) mockNetInfo := mock_network.NewMockNetwork(ctrl) - callback := func(request *scanner.Request) { - assert.NotNil(st, request) - } + requestNotifier := make(chan *scanner.Request) + + go func() { + r := <-requestNotifier + assert.NotNil(st, r) + }() arpScanner := scanner.NewArpScanner( []string{"172.17.1.1"}, mockNetInfo, scanner.WithPacketCapture(cap), - scanner.WithRequestNotifications(callback), + scanner.WithRequestNotifications(requestNotifier), ) wg := sync.WaitGroup{} @@ -558,15 +617,18 @@ func TestArpScanner(t *testing.T) { handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) mockNetInfo := mock_network.NewMockNetwork(ctrl) - callback := func(request *scanner.Request) { - assert.NotNil(st, request) - } + requestNotifier := make(chan *scanner.Request) + + go func() { + r := <-requestNotifier + assert.NotNil(st, r) + }() arpScanner := scanner.NewArpScanner( []string{"172.17.1.1"}, mockNetInfo, scanner.WithPacketCapture(cap), - scanner.WithRequestNotifications(callback), + scanner.WithRequestNotifications(requestNotifier), ) mockNetInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) @@ -605,15 +667,18 @@ func TestArpScanner(t *testing.T) { handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) mockNetInfo := mock_network.NewMockNetwork(ctrl) - callback := func(request *scanner.Request) { - assert.NotNil(st, request) - } + requestNotifier := make(chan *scanner.Request) + + go func() { + r := <-requestNotifier + assert.NotNil(st, r) + }() arpScanner := scanner.NewArpScanner( []string{"172.17.1.1"}, mockNetInfo, scanner.WithPacketCapture(cap), - scanner.WithRequestNotifications(callback), + scanner.WithRequestNotifications(requestNotifier), ) mockNetInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) diff --git a/pkg/scanner/fullscan.go b/pkg/scanner/fullscan.go index 1e50aca..caaaa1a 100644 --- a/pkg/scanner/fullscan.go +++ b/pkg/scanner/fullscan.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/robgonnella/go-lanscan/internal/logger" "github.com/robgonnella/go-lanscan/internal/util" "github.com/robgonnella/go-lanscan/pkg/network" "github.com/robgonnella/go-lanscan/pkg/oui" @@ -30,6 +31,7 @@ type FullScanner struct { scanning bool scanningMux *sync.RWMutex deviceMux *sync.RWMutex + debug logger.DebugLogger } func NewFullScanner( @@ -68,6 +70,7 @@ func NewFullScanner( scanning: false, scanningMux: &sync.RWMutex{}, deviceMux: &sync.RWMutex{}, + debug: logger.NewDebugLogger(), } for _, o := range options { @@ -142,11 +145,12 @@ func (s *FullScanner) Stop() { if s.synScanner != nil { s.synScanner.Stop() } + s.debug.Info().Msg("all scanners stopped") } -func (s *FullScanner) SetRequestNotifications(cb func(req *Request)) { - s.arpScanner.SetRequestNotifications(cb) - s.synScanner.SetRequestNotifications(cb) +func (s *FullScanner) SetRequestNotifications(c chan *Request) { + s.arpScanner.SetRequestNotifications(c) + s.synScanner.SetRequestNotifications(c) } func (s *FullScanner) SetIdleTimeout(d time.Duration) { diff --git a/pkg/scanner/options.go b/pkg/scanner/options.go index 4994157..94c2710 100644 --- a/pkg/scanner/options.go +++ b/pkg/scanner/options.go @@ -11,13 +11,13 @@ import ( // How long to wait before sending next packet // the faster you send packets the more packets // will be missed when reading -const defaultAccuracy = time.Millisecond * 100 +const defaultAccuracy = time.Microsecond * 100 type ScannerOption = func(s Scanner) -func WithRequestNotifications(cb func(a *Request)) ScannerOption { +func WithRequestNotifications(c chan *Request) ScannerOption { return func(s Scanner) { - s.SetRequestNotifications(cb) + s.SetRequestNotifications(c) } } diff --git a/pkg/scanner/options_test.go b/pkg/scanner/options_test.go index 4ee1261..0fbf11f 100644 --- a/pkg/scanner/options_test.go +++ b/pkg/scanner/options_test.go @@ -25,13 +25,14 @@ func TestOptions(t *testing.T) { t.Run("sets options", func(st *testing.T) { netInfo := mock_network.NewMockNetwork(ctrl) + requestNotifier := make(chan *scanner.Request) scanner.NewArpScanner( []string{}, netInfo, scanner.WithIdleTimeout(time.Second*5), scanner.WithPacketCapture(testPacketCapture), - scanner.WithRequestNotifications(func(r *scanner.Request) {}), + scanner.WithRequestNotifications(requestNotifier), scanner.WithVendorInfo(vendorRepo), ) @@ -42,7 +43,7 @@ func TestOptions(t *testing.T) { 54321, scanner.WithIdleTimeout(time.Second*5), scanner.WithPacketCapture(testPacketCapture), - scanner.WithRequestNotifications(func(r *scanner.Request) {}), + scanner.WithRequestNotifications(requestNotifier), scanner.WithVendorInfo(vendorRepo), ) @@ -53,7 +54,7 @@ func TestOptions(t *testing.T) { 54321, scanner.WithIdleTimeout(time.Second*5), scanner.WithPacketCapture(testPacketCapture), - scanner.WithRequestNotifications(func(r *scanner.Request) {}), + scanner.WithRequestNotifications(requestNotifier), scanner.WithVendorInfo(vendorRepo), ) }) diff --git a/pkg/scanner/synscan.go b/pkg/scanner/synscan.go index 5601b8d..757a308 100644 --- a/pkg/scanner/synscan.go +++ b/pkg/scanner/synscan.go @@ -5,12 +5,12 @@ package scanner import ( "context" "fmt" - "net" "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/thediveo/netdb" "github.com/robgonnella/go-lanscan/internal/logger" @@ -19,6 +19,11 @@ import ( "github.com/robgonnella/go-lanscan/pkg/oui" ) +type SynPacket struct { + Ip4 *layers.IPv4 + TCP *layers.TCP +} + type SynScanner struct { ctx context.Context cancel context.CancelFunc @@ -29,7 +34,7 @@ type SynScanner struct { cap PacketCapture handle PacketCaptureHandle resultChan chan *ScanResult - notificationCB func(a *Request) + requestNotifier chan *Request scanning bool lastPacketSentAt time.Time idleTimeout time.Duration @@ -102,7 +107,7 @@ func (s *SynScanner) Scan() error { s.networkInfo.Interface().Name, 65536, true, - s.idleTimeout, + pcap.BlockForever, ) if err != nil { @@ -165,8 +170,8 @@ func (s *SynScanner) Stop() { } } -func (s *SynScanner) SetRequestNotifications(cb func(a *Request)) { - s.notificationCB = cb +func (s *SynScanner) SetRequestNotifications(c chan *Request) { + s.requestNotifier = c } func (s *SynScanner) SetIdleTimeout(duration time.Duration) { @@ -186,18 +191,60 @@ func (s *SynScanner) SetTargets(targets []*ArpScanResult) { } func (s *SynScanner) readPackets() { - packetSource := gopacket.NewPacketSource(s.handle, layers.LayerTypeEthernet) - packetSource.DecodeOptions.NoCopy = true - packetSource.DecodeOptions.Lazy = true + stopChan := make(chan struct{}) + + go func() { + for { + select { + case <-stopChan: + return + default: + var eth layers.Ethernet + var ip4 layers.IPv4 + var tcp layers.TCP + var payload gopacket.Payload + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &ip4, &tcp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("syn: read packet error") + } + + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("syn: decode packet error") + } + + synPacket := &SynPacket{} + + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeIPv4: + synPacket.Ip4 = &ip4 + case layers.LayerTypeTCP: + synPacket.TCP = &tcp + } + } + + if synPacket.Ip4 != nil && synPacket.TCP != nil { + go s.handlePacket(synPacket) + } + } + } + }() - defer s.reset() + defer func() { + stopChan <- struct{}{} + s.reset() + }() for { select { case <-s.ctx.Done(): return - case packet := <-packetSource.Packets(): - go s.handlePacket(packet) default: s.packetSentAtMux.RLock() packetSentAt := s.lastPacketSentAt @@ -213,19 +260,13 @@ func (s *SynScanner) readPackets() { } } -func (s *SynScanner) handlePacket(packet gopacket.Packet) { - netLayer := packet.NetworkLayer() - - if netLayer == nil { - return - } - - srcIP := netLayer.NetworkFlow().Src().String() +func (s *SynScanner) handlePacket(synPacket *SynPacket) { + srcIP := synPacket.Ip4.SrcIP var targetIdx int isExpected := util.SliceIncludesFunc(s.targets, func(t *ArpScanResult, i int) bool { - if t.IP.Equal(net.ParseIP(srcIP)) { + if t.IP.Equal(srcIP) { targetIdx = i return true } @@ -239,23 +280,23 @@ func (s *SynScanner) handlePacket(packet gopacket.Packet) { target := s.targets[targetIdx] - tcpLayer := packet.Layer(layers.LayerTypeTCP) - - if tcpLayer == nil { + if synPacket.TCP.DstPort != layers.TCPPort(s.listenPort) { return } - tcp := tcpLayer.(*layers.TCP) - - if tcp.DstPort != layers.TCPPort(s.listenPort) { - return + fields := map[string]interface{}{ + "ip": target.IP, + "port": synPacket.TCP.SrcPort.String(), + "open": synPacket.TCP.SYN && synPacket.TCP.ACK, } - if tcp.SYN && tcp.ACK { + s.debug.Info().Fields(fields).Msg("received response") + + if synPacket.TCP.SYN && synPacket.TCP.ACK { serviceName := "" s.serviceQueryMux.Lock() - service := netdb.ServiceByPort(int(tcp.SrcPort), "") + service := netdb.ServiceByPort(int(synPacket.TCP.SrcPort), "") s.serviceQueryMux.Unlock() if service != nil { @@ -267,7 +308,7 @@ func (s *SynScanner) handlePacket(packet gopacket.Packet) { IP: target.IP, Status: StatusOnline, Port: Port{ - ID: uint16(tcp.SrcPort), + ID: uint16(synPacket.TCP.SrcPort), Service: serviceName, Status: PortOpen, }, @@ -320,12 +361,14 @@ func (s *SynScanner) writePacketData(target *ArpScanResult, port uint16) error { return err } - if s.notificationCB != nil { - go s.notificationCB(&Request{ - Type: SynRequest, - IP: target.IP.String(), - Port: port, - }) + if s.requestNotifier != nil { + go func() { + s.requestNotifier <- &Request{ + Type: SynRequest, + IP: target.IP.String(), + Port: port, + } + }() } return nil diff --git a/pkg/scanner/synscan_test.go b/pkg/scanner/synscan_test.go index c9a19ad..d225474 100644 --- a/pkg/scanner/synscan_test.go +++ b/pkg/scanner/synscan_test.go @@ -119,6 +119,74 @@ func TestSynScanner(t *testing.T) { assert.ErrorIs(st, mockErr, err) }) + t.Run("prints debug message if reading packets returns error", func(st *testing.T) { + cap := mock_scanner.NewMockPacketCapture(ctrl) + handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) + netInfo := mock_network.NewMockNetwork(ctrl) + + wg := sync.WaitGroup{} + wg.Add(1) + + listenPort := uint16(54321) + + synScanner := scanner.NewSynScanner( + []*scanner.ArpScanResult{ + { + IP: net.ParseIP("172.17.1.1"), + MAC: mockInterface.HardwareAddr, + Vendor: "unknown", + }, + }, + netInfo, + []string{"22"}, + listenPort, + scanner.WithPacketCapture(cap), + ) + + netInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) + netInfo.EXPECT().Cidr().AnyTimes().Return(cidr) + netInfo.EXPECT().UserIP().Return(mockUserIP) + + cap.EXPECT().OpenLive( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any()).Return(handle, nil) + + cap.EXPECT().SerializeLayers(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + handle.EXPECT().SetBPFFilter(gomock.Any()) + handle.EXPECT().Close().AnyTimes() + + handle.EXPECT().WritePacketData(gomock.Any()).DoAndReturn(func(data []byte) (err error) { + defer func() { + synScanner.Stop() + wg.Done() + }() + return nil + }) + + firstCall := true + + handle.EXPECT().ReadPacketData().AnyTimes().DoAndReturn(func() (data []byte, ci gopacket.CaptureInfo, err error) { + if firstCall { + firstCall = false + return nil, gopacket.CaptureInfo{}, errors.New("mock ReadPacketData error") + } + return test_helper.NewSynWithAckResponsePacketBytes( + net.ParseIP("172.17.1.1"), + 22, + listenPort, + ) + }) + + err := synScanner.Scan() + + wg.Wait() + + assert.NoError(st, err) + }) + t.Run("returns error if SetBPFFilter returns error", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) @@ -453,19 +521,21 @@ func TestSynScanner(t *testing.T) { assert.NoError(st, err) }) - t.Run("calls notification callback", func(st *testing.T) { + t.Run("sends request notifications", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) netInfo := mock_network.NewMockNetwork(ctrl) - callbackCalled := false + requestNotifier := make(chan *scanner.Request) + notified := false wg := sync.WaitGroup{} wg.Add(2) - callback := func(req *scanner.Request) { - callbackCalled = true + go func() { + <-requestNotifier + notified = true wg.Done() - } + }() listenPort := uint16(54321) @@ -481,7 +551,7 @@ func TestSynScanner(t *testing.T) { []string{"22"}, listenPort, scanner.WithPacketCapture(cap), - scanner.WithRequestNotifications(callback), + scanner.WithRequestNotifications(requestNotifier), ) netInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) @@ -521,6 +591,6 @@ func TestSynScanner(t *testing.T) { assert.NoError(st, err) - assert.True(st, callbackCalled) + assert.True(st, notified) }) } diff --git a/pkg/scanner/types.go b/pkg/scanner/types.go index d7dcd78..a344ed0 100644 --- a/pkg/scanner/types.go +++ b/pkg/scanner/types.go @@ -30,7 +30,7 @@ type Scanner interface { Scan() error Stop() Results() chan *ScanResult - SetRequestNotifications(cb func(a *Request)) + SetRequestNotifications(c chan *Request) SetIdleTimeout(d time.Duration) IncludeVendorInfo(repo oui.VendorRepo) SetPacketCapture(cap PacketCapture) @@ -39,6 +39,7 @@ type Scanner interface { type PacketCaptureHandle interface { Close() ReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) + ZeroCopyReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) WritePacketData(data []byte) (err error) SetBPFFilter(expr string) (err error) }