diff --git a/internal/core/core.go b/internal/core/core.go index d52adbe..48ed181 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -152,13 +152,12 @@ OUTER: go c.processSynResult(res.Payload.(*scanner.SynScanResult)) case scanner.SYNDone: c.processSynDone() + c.log.Info().Msg("received syn done: breaking loop") break OUTER } } } - c.scanner.Stop() - c.log.Info().Str("duration", time.Since(start).String()).Msg("go-lanscan complete") return nil diff --git a/internal/core/core_test.go b/internal/core/core_test.go index 4334cd5..735bb64 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -225,8 +225,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -294,8 +292,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -346,8 +342,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -441,8 +435,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -537,8 +529,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -611,8 +601,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 1, @@ -646,8 +634,6 @@ func TestCore(t *testing.T) { return nil }) - mockScanner.EXPECT().Stop() - runner.Initialize( mockScanner, 10, diff --git a/mock/scanner/scanner.go b/mock/scanner/scanner.go index 8f48373..3248d3c 100644 --- a/mock/scanner/scanner.go +++ b/mock/scanner/scanner.go @@ -208,6 +208,22 @@ func (mr *MockPacketCaptureHandleMockRecorder) WritePacketData(arg0 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacketData", reflect.TypeOf((*MockPacketCaptureHandle)(nil).WritePacketData), arg0) } +// ZeroCopyReadPacketData mocks base method. +func (m *MockPacketCaptureHandle) ZeroCopyReadPacketData() ([]byte, gopacket.CaptureInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ZeroCopyReadPacketData") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(gopacket.CaptureInfo) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ZeroCopyReadPacketData indicates an expected call of ZeroCopyReadPacketData. +func (mr *MockPacketCaptureHandleMockRecorder) ZeroCopyReadPacketData() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZeroCopyReadPacketData", reflect.TypeOf((*MockPacketCaptureHandle)(nil).ZeroCopyReadPacketData)) +} + // MockPacketCapture is a mock of PacketCapture interface. type MockPacketCapture struct { ctrl *gomock.Controller diff --git a/pkg/internal/test-helper/test-helper.go b/pkg/internal/test-helper/test-helper.go index a74d3e4..02a4dc5 100644 --- a/pkg/internal/test-helper/test-helper.go +++ b/pkg/internal/test-helper/test-helper.go @@ -28,6 +28,8 @@ func NewArpReplyReadResult(srcIP net.IP, srcHwAddr net.HardwareAddr) (data []byt DstProtAddress: []byte{192, 168, 1, 1}, // Target IP address } + var payload gopacket.Payload + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -36,7 +38,7 @@ func NewArpReplyReadResult(srcIP net.IP, srcHwAddr net.HardwareAddr) (data []byt } // Serialize the ARP packet. - gopacket.SerializeLayers(buf, opts, ð, &arp) + gopacket.SerializeLayers(buf, opts, ð, &arp, &payload) return buf.Bytes(), gopacket.CaptureInfo{}, nil } @@ -60,6 +62,8 @@ func NewArpRequestReadResult() (data []byte, ci gopacket.CaptureInfo, err error) DstProtAddress: []byte{192, 168, 1, 1}, // Target IP address } + var payload gopacket.Payload + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -68,7 +72,7 @@ func NewArpRequestReadResult() (data []byte, ci gopacket.CaptureInfo, err error) } // Serialize the ARP packet. - gopacket.SerializeLayers(buf, opts, ð, &arp) + gopacket.SerializeLayers(buf, opts, ð, &arp, &payload) return buf.Bytes(), gopacket.CaptureInfo{}, nil } diff --git a/pkg/scanner/arpscan.go b/pkg/scanner/arpscan.go index 4681dae..7aa8211 100644 --- a/pkg/scanner/arpscan.go +++ b/pkg/scanner/arpscan.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/robgonnella/go-lanscan/internal/logger" "github.com/robgonnella/go-lanscan/internal/util" @@ -93,7 +94,7 @@ func (s *ArpScanner) Scan() error { s.networkInfo.Interface().Name, 65536, true, - s.idleTimeout, + pcap.BlockForever, ) if err != nil { @@ -161,22 +162,55 @@ func (s *ArpScanner) SetPacketCapture(cap PacketCapture) { } func (s *ArpScanner) readPackets() { - packetSource := gopacket.NewPacketSource(s.handle, layers.LayerTypeEthernet) - packetSource.DecodeOptions.NoCopy = true - packetSource.DecodeOptions.Lazy = true + stopChan := make(chan struct{}) - defer s.reset() + go func() { + for { + select { + case <-stopChan: + return + default: + var eth layers.Ethernet + var arp layers.ARP + var payload gopacket.Payload + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &arp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error reading packet") + continue + } + + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error decoding packet") + continue + } + + OUTER: + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeARP: + go s.handleARPLayer(&arp) + break OUTER + } + } + } + } + }() + + defer func() { + stopChan <- struct{}{} + s.reset() + }() for { select { case <-s.ctx.Done(): return - case packet := <-packetSource.Packets(): - arpLayer := packet.Layer(layers.LayerTypeARP) - - if arpLayer != nil { - go s.handleARPLayer(arpLayer.(*layers.ARP)) - } default: s.packetSentAtMux.RLock() packetSentAt := s.lastPacketSentAt diff --git a/pkg/scanner/arpscan_test.go b/pkg/scanner/arpscan_test.go index 1237a05..7c51f15 100644 --- a/pkg/scanner/arpscan_test.go +++ b/pkg/scanner/arpscan_test.go @@ -111,6 +111,62 @@ func TestArpScanner(t *testing.T) { assert.ErrorIs(st, mockErr, err) }) + t.Run("prints debug message if reading packets returns error", func(st *testing.T) { + cap := mock_scanner.NewMockPacketCapture(ctrl) + handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) + mockNetInfo := mock_network.NewMockNetwork(ctrl) + + arpScanner := scanner.NewArpScanner( + []string{}, + mockNetInfo, + scanner.WithPacketCapture(cap), + ) + + wg := sync.WaitGroup{} + wg.Add(1) + + mockNetInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) + mockNetInfo.EXPECT().IPNet().Return(mockIPNet).AnyTimes() + mockNetInfo.EXPECT().UserIP().Return(mockUserIP) + mockNetInfo.EXPECT().Cidr().AnyTimes().Return(cidr) + + cap.EXPECT().OpenLive( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any()).Return(handle, nil) + + cap.EXPECT().SerializeLayers(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + handle.EXPECT().Close().AnyTimes() + + handle.EXPECT().WritePacketData(gomock.Any()).DoAndReturn(func(data []byte) (err error) { + defer func() { + arpScanner.Stop() + wg.Done() + }() + return nil + }) + + firstCall := true + handle.EXPECT().ReadPacketData().AnyTimes().DoAndReturn(func() (data []byte, ci gopacket.CaptureInfo, err error) { + if firstCall { + firstCall = false + return nil, gopacket.CaptureInfo{}, errors.New("mock ReadPacketData error") + } + return test_helper.NewArpReplyReadResult( + mockNonIncludedArpSrcIP, + []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + ) + }) + + err := arpScanner.Scan() + + wg.Wait() + + assert.NoError(st, err) + }) + t.Run("performs arp scan on default network info", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) diff --git a/pkg/scanner/fullscan.go b/pkg/scanner/fullscan.go index 1e50aca..c5842dd 100644 --- a/pkg/scanner/fullscan.go +++ b/pkg/scanner/fullscan.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/robgonnella/go-lanscan/internal/logger" "github.com/robgonnella/go-lanscan/internal/util" "github.com/robgonnella/go-lanscan/pkg/network" "github.com/robgonnella/go-lanscan/pkg/oui" @@ -30,6 +31,7 @@ type FullScanner struct { scanning bool scanningMux *sync.RWMutex deviceMux *sync.RWMutex + debug logger.DebugLogger } func NewFullScanner( @@ -68,6 +70,7 @@ func NewFullScanner( scanning: false, scanningMux: &sync.RWMutex{}, deviceMux: &sync.RWMutex{}, + debug: logger.NewDebugLogger(), } for _, o := range options { @@ -142,6 +145,7 @@ func (s *FullScanner) Stop() { if s.synScanner != nil { s.synScanner.Stop() } + s.debug.Info().Msg("all scanners stopped") } func (s *FullScanner) SetRequestNotifications(cb func(req *Request)) { diff --git a/pkg/scanner/options.go b/pkg/scanner/options.go index 4994157..8c492dc 100644 --- a/pkg/scanner/options.go +++ b/pkg/scanner/options.go @@ -11,7 +11,7 @@ import ( // How long to wait before sending next packet // the faster you send packets the more packets // will be missed when reading -const defaultAccuracy = time.Millisecond * 100 +const defaultAccuracy = time.Microsecond * 100 type ScannerOption = func(s Scanner) diff --git a/pkg/scanner/synscan.go b/pkg/scanner/synscan.go index 5601b8d..61b6ed1 100644 --- a/pkg/scanner/synscan.go +++ b/pkg/scanner/synscan.go @@ -5,12 +5,12 @@ package scanner import ( "context" "fmt" - "net" "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/thediveo/netdb" "github.com/robgonnella/go-lanscan/internal/logger" @@ -19,6 +19,11 @@ import ( "github.com/robgonnella/go-lanscan/pkg/oui" ) +type SynPacket struct { + Ip4 *layers.IPv4 + TCP *layers.TCP +} + type SynScanner struct { ctx context.Context cancel context.CancelFunc @@ -102,7 +107,7 @@ func (s *SynScanner) Scan() error { s.networkInfo.Interface().Name, 65536, true, - s.idleTimeout, + pcap.BlockForever, ) if err != nil { @@ -186,18 +191,60 @@ func (s *SynScanner) SetTargets(targets []*ArpScanResult) { } func (s *SynScanner) readPackets() { - packetSource := gopacket.NewPacketSource(s.handle, layers.LayerTypeEthernet) - packetSource.DecodeOptions.NoCopy = true - packetSource.DecodeOptions.Lazy = true + stopChan := make(chan struct{}) + + go func() { + for { + select { + case <-stopChan: + return + default: + var eth layers.Ethernet + var ip4 layers.IPv4 + var tcp layers.TCP + var payload gopacket.Payload + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &ip4, &tcp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("syn: read packet error") + } + + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("syn: decode packet error") + } + + synPacket := &SynPacket{} + + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeIPv4: + synPacket.Ip4 = &ip4 + case layers.LayerTypeTCP: + synPacket.TCP = &tcp + } + } + + if synPacket.Ip4 != nil && synPacket.TCP != nil { + go s.handlePacket(synPacket) + } + } + } + }() - defer s.reset() + defer func() { + stopChan <- struct{}{} + s.reset() + }() for { select { case <-s.ctx.Done(): return - case packet := <-packetSource.Packets(): - go s.handlePacket(packet) default: s.packetSentAtMux.RLock() packetSentAt := s.lastPacketSentAt @@ -213,19 +260,13 @@ func (s *SynScanner) readPackets() { } } -func (s *SynScanner) handlePacket(packet gopacket.Packet) { - netLayer := packet.NetworkLayer() - - if netLayer == nil { - return - } - - srcIP := netLayer.NetworkFlow().Src().String() +func (s *SynScanner) handlePacket(synPacket *SynPacket) { + srcIP := synPacket.Ip4.SrcIP var targetIdx int isExpected := util.SliceIncludesFunc(s.targets, func(t *ArpScanResult, i int) bool { - if t.IP.Equal(net.ParseIP(srcIP)) { + if t.IP.Equal(srcIP) { targetIdx = i return true } @@ -239,23 +280,23 @@ func (s *SynScanner) handlePacket(packet gopacket.Packet) { target := s.targets[targetIdx] - tcpLayer := packet.Layer(layers.LayerTypeTCP) - - if tcpLayer == nil { + if synPacket.TCP.DstPort != layers.TCPPort(s.listenPort) { return } - tcp := tcpLayer.(*layers.TCP) - - if tcp.DstPort != layers.TCPPort(s.listenPort) { - return + fields := map[string]interface{}{ + "ip": target.IP, + "port": synPacket.TCP.SrcPort.String(), + "open": synPacket.TCP.SYN && synPacket.TCP.ACK, } - if tcp.SYN && tcp.ACK { + s.debug.Info().Fields(fields).Msg("received response") + + if synPacket.TCP.SYN && synPacket.TCP.ACK { serviceName := "" s.serviceQueryMux.Lock() - service := netdb.ServiceByPort(int(tcp.SrcPort), "") + service := netdb.ServiceByPort(int(synPacket.TCP.SrcPort), "") s.serviceQueryMux.Unlock() if service != nil { @@ -267,7 +308,7 @@ func (s *SynScanner) handlePacket(packet gopacket.Packet) { IP: target.IP, Status: StatusOnline, Port: Port{ - ID: uint16(tcp.SrcPort), + ID: uint16(synPacket.TCP.SrcPort), Service: serviceName, Status: PortOpen, }, diff --git a/pkg/scanner/synscan_test.go b/pkg/scanner/synscan_test.go index c9a19ad..730d1ef 100644 --- a/pkg/scanner/synscan_test.go +++ b/pkg/scanner/synscan_test.go @@ -119,6 +119,74 @@ func TestSynScanner(t *testing.T) { assert.ErrorIs(st, mockErr, err) }) + t.Run("prints debug message if reading packets returns error", func(st *testing.T) { + cap := mock_scanner.NewMockPacketCapture(ctrl) + handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) + netInfo := mock_network.NewMockNetwork(ctrl) + + wg := sync.WaitGroup{} + wg.Add(1) + + listenPort := uint16(54321) + + synScanner := scanner.NewSynScanner( + []*scanner.ArpScanResult{ + { + IP: net.ParseIP("172.17.1.1"), + MAC: mockInterface.HardwareAddr, + Vendor: "unknown", + }, + }, + netInfo, + []string{"22"}, + listenPort, + scanner.WithPacketCapture(cap), + ) + + netInfo.EXPECT().Interface().AnyTimes().Return(mockInterface) + netInfo.EXPECT().Cidr().AnyTimes().Return(cidr) + netInfo.EXPECT().UserIP().Return(mockUserIP) + + cap.EXPECT().OpenLive( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any()).Return(handle, nil) + + cap.EXPECT().SerializeLayers(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + handle.EXPECT().SetBPFFilter(gomock.Any()) + handle.EXPECT().Close().AnyTimes() + + handle.EXPECT().WritePacketData(gomock.Any()).DoAndReturn(func(data []byte) (err error) { + defer func() { + synScanner.Stop() + wg.Done() + }() + return nil + }) + + firstCall := true + + handle.EXPECT().ReadPacketData().AnyTimes().DoAndReturn(func() (data []byte, ci gopacket.CaptureInfo, err error) { + if firstCall { + firstCall = false + return nil, gopacket.CaptureInfo{}, errors.New("mock ReadPacketData error") + } + return test_helper.NewSynWithAckResponsePacketBytes( + net.ParseIP("172.17.1.1"), + 22, + listenPort, + ) + }) + + err := synScanner.Scan() + + wg.Wait() + + assert.NoError(st, err) + }) + t.Run("returns error if SetBPFFilter returns error", func(st *testing.T) { cap := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) diff --git a/pkg/scanner/types.go b/pkg/scanner/types.go index d7dcd78..8508261 100644 --- a/pkg/scanner/types.go +++ b/pkg/scanner/types.go @@ -39,6 +39,7 @@ type Scanner interface { type PacketCaptureHandle interface { Close() ReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) + ZeroCopyReadPacketData() (data []byte, ci gopacket.CaptureInfo, err error) WritePacketData(data []byte) (err error) SetBPFFilter(expr string) (err error) }