From 6b5ed16ce843da84de8abeca4ace51644232888a Mon Sep 17 00:00:00 2001 From: Rob Gonnella Date: Sat, 13 Jul 2024 20:49:27 -0400 Subject: [PATCH] Minor refactor in scanners Cleans up graceful exit of scanners --- README.md | 2 +- pkg/scanner/arpscan.go | 112 +++++++++++++++++--------------- pkg/scanner/fullscan.go | 16 ++--- pkg/scanner/synscan.go | 125 ++++++++++++++++++++---------------- pkg/scanner/synscan_test.go | 2 +- 5 files changed, 137 insertions(+), 120 deletions(-) diff --git a/README.md b/README.md index 6456468..07157eb 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # go-lanscan -![Coverage](https://img.shields.io/badge/Coverage-91.8%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-91.3%25-brightgreen) A network cli and golang package that allows you to perform arp and syn scanning on a local area network. diff --git a/pkg/scanner/arpscan.go b/pkg/scanner/arpscan.go index 00f95b0..3635444 100644 --- a/pkg/scanner/arpscan.go +++ b/pkg/scanner/arpscan.go @@ -74,6 +74,8 @@ func (s *ArpScanner) Results() chan *ScanResult { // Scan implements the Scan method for ARP scanning func (s *ArpScanner) Scan() error { + defer s.reset() + fields := map[string]interface{}{ "interface": s.networkInfo.Interface().Name, "cidr": s.networkInfo.Cidr(), @@ -92,7 +94,6 @@ func (s *ArpScanner) Scan() error { s.scanningMux.Lock() s.scanning = true - // open a new handle each time so we don't hit buffer overflow error handle, err := s.cap.OpenLive( s.networkInfo.Interface().Name, 65536, @@ -109,6 +110,9 @@ func (s *ArpScanner) Scan() error { s.scanningMux.Unlock() s.handle = handle + timeout := make(chan struct{}) + + go s.startPacketReceiveTimeout(timeout) go s.readPackets() limiter := time.NewTicker(s.timing) @@ -131,8 +135,18 @@ func (s *ArpScanner) Scan() error { } s.packetSentAtMux.Lock() - defer s.packetSentAtMux.Unlock() s.lastPacketSentAt = time.Now() + s.packetSentAtMux.Unlock() + + <-timeout + + go s.Stop() + + go func() { + s.resultChan <- &ScanResult{ + Type: ARPDone, + } + }() return err } @@ -183,66 +197,64 @@ func (s *ArpScanner) SetPacketCapture(cap PacketCapture) { s.cap = cap } -func (s *ArpScanner) readPackets() { - done := make(chan struct{}) +func (s *ArpScanner) startPacketReceiveTimeout(timeout chan<- struct{}) { + for { + select { + case <-s.cancel: + go func() { + timeout <- struct{}{} + }() + return + default: + s.packetSentAtMux.RLock() + packetSentAt := s.lastPacketSentAt + s.packetSentAtMux.RUnlock() - go func() { - for { - select { - case <-done: + if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout { + go func() { + timeout <- struct{}{} + }() return - default: - var eth layers.Ethernet - var arp layers.ARP - var payload gopacket.Payload - - parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &arp, &payload) - decoded := []gopacket.LayerType{} - packetData, _, err := s.handle.ReadPacketData() - - if err != nil { - s.debug.Error().Err(err).Msg("arp: error reading packet") - continue - } - - err = parser.DecodeLayers(packetData, &decoded) - - if err != nil { - s.debug.Error().Err(err).Msg("arp: error decoding packet") - continue - } - - OUTER: - for _, layerType := range decoded { - switch layerType { - case layers.LayerTypeARP: - go s.handleARPLayer(&arp) - break OUTER - } - } } - } - }() - defer func() { - go close(done) - s.reset() - }() + time.Sleep(time.Millisecond * 100) + } + } +} +func (s *ArpScanner) readPackets() { for { select { case <-s.cancel: return default: - s.packetSentAtMux.RLock() - packetSentAt := s.lastPacketSentAt - s.packetSentAtMux.RUnlock() + var eth layers.Ethernet + var arp layers.ARP + var payload gopacket.Payload - if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout { - s.resultChan <- &ScanResult{ - Type: ARPDone, + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &arp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error reading packet") + continue + } + + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("arp: error decoding packet") + continue + } + + INNER: + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeARP: + go s.handleARPLayer(&arp) + break INNER } - return } } } diff --git a/pkg/scanner/fullscan.go b/pkg/scanner/fullscan.go index 790e057..d053972 100644 --- a/pkg/scanner/fullscan.go +++ b/pkg/scanner/fullscan.go @@ -4,7 +4,6 @@ package scanner import ( "bytes" - "fmt" "slices" "sync" "time" @@ -116,18 +115,11 @@ func (s *FullScanner) Scan() error { go s.handleArpDone() } case r := <-s.synScanner.Results(): - switch r.Type { - case SYNResult: - go func() { - s.results <- r - }() - case SYNDone: - go func() { - s.results <- r - }() + go func() { + s.results <- r + }() + if r.Type == SYNDone { return nil - default: - return fmt.Errorf("unknown result type: %s", r.Type) } case err := <-s.errorChan: return err diff --git a/pkg/scanner/synscan.go b/pkg/scanner/synscan.go index 842aec1..bb7d5a2 100644 --- a/pkg/scanner/synscan.go +++ b/pkg/scanner/synscan.go @@ -86,6 +86,8 @@ func (s *SynScanner) Results() chan *ScanResult { // Scan implements SYN scanning func (s *SynScanner) Scan() error { + defer s.reset() + fields := map[string]interface{}{ "interface": s.networkInfo.Interface().Name, "cidr": s.networkInfo.Cidr(), @@ -136,6 +138,9 @@ func (s *SynScanner) Scan() error { s.handle = handle + timeout := make(chan struct{}) + + go s.startPacketReceiveTimeout(timeout) go s.readPackets() limiter := time.NewTicker(s.timing) @@ -159,8 +164,18 @@ func (s *SynScanner) Scan() error { } s.packetSentAtMux.Lock() - defer s.packetSentAtMux.Unlock() s.lastPacketSentAt = time.Now() + s.packetSentAtMux.Unlock() + + <-timeout + + go s.Stop() + + go func() { + s.resultChan <- &ScanResult{ + Type: SYNDone, + } + }() return nil } @@ -213,73 +228,71 @@ func (s *SynScanner) SetTargets(targets []*ArpScanResult) { s.targets = targets } -func (s *SynScanner) readPackets() { - done := make(chan struct{}) +func (s *SynScanner) startPacketReceiveTimeout(timeout chan<- struct{}) { + for { + select { + case <-s.cancel: + go func() { + timeout <- struct{}{} + }() + return + default: + s.packetSentAtMux.RLock() + packetSentAt := s.lastPacketSentAt + s.packetSentAtMux.RUnlock() - go func() { - for { - select { - case <-done: + if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout { + go func() { + timeout <- struct{}{} + }() return - default: - var eth layers.Ethernet - var ip4 layers.IPv4 - var tcp layers.TCP - var payload gopacket.Payload - - parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &ip4, &tcp, &payload) - decoded := []gopacket.LayerType{} - packetData, _, err := s.handle.ReadPacketData() - - if err != nil { - s.debug.Error().Err(err).Msg("syn: read packet error") - continue - } - - err = parser.DecodeLayers(packetData, &decoded) - - if err != nil { - s.debug.Error().Err(err).Msg("syn: decode packet error") - continue - } - - synPacket := &SynPacket{} - - for _, layerType := range decoded { - switch layerType { - case layers.LayerTypeIPv4: - synPacket.IP4 = &ip4 - case layers.LayerTypeTCP: - synPacket.TCP = &tcp - } - } - - if synPacket.IP4 != nil && synPacket.TCP != nil { - go s.handlePacket(synPacket) - } } - } - }() - defer func() { - go close(done) - s.reset() - }() + time.Sleep(time.Millisecond * 100) + } + } +} +func (s *SynScanner) readPackets() { for { select { case <-s.cancel: return default: - s.packetSentAtMux.RLock() - packetSentAt := s.lastPacketSentAt - s.packetSentAtMux.RUnlock() + var eth layers.Ethernet + var ip4 layers.IPv4 + var tcp layers.TCP + var payload gopacket.Payload + + parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, ð, &ip4, &tcp, &payload) + decoded := []gopacket.LayerType{} + packetData, _, err := s.handle.ReadPacketData() + + if err != nil { + s.debug.Error().Err(err).Msg("syn: read packet error") + continue + } - if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout { - s.resultChan <- &ScanResult{ - Type: SYNDone, + err = parser.DecodeLayers(packetData, &decoded) + + if err != nil { + s.debug.Error().Err(err).Msg("syn: decode packet error") + continue + } + + synPacket := &SynPacket{} + + for _, layerType := range decoded { + switch layerType { + case layers.LayerTypeIPv4: + synPacket.IP4 = &ip4 + case layers.LayerTypeTCP: + synPacket.TCP = &tcp } - return + } + + if synPacket.IP4 != nil && synPacket.TCP != nil { + go s.handlePacket(synPacket) } } } diff --git a/pkg/scanner/synscan_test.go b/pkg/scanner/synscan_test.go index 75d5844..c8a56fc 100644 --- a/pkg/scanner/synscan_test.go +++ b/pkg/scanner/synscan_test.go @@ -339,7 +339,7 @@ func TestSynScanner(t *testing.T) { assert.ErrorIs(st, mockErr, err) }) - t.Run("performs syn scan ", func(st *testing.T) { + t.Run("performs syn scan", func(st *testing.T) { capture := mock_scanner.NewMockPacketCapture(ctrl) handle := mock_scanner.NewMockPacketCaptureHandle(ctrl) netInfo := mock_network.NewMockNetwork(ctrl)