Skip to content

Commit

Permalink
Minor refactor in scanners
Browse files Browse the repository at this point in the history
Cleans up graceful exit of scanners
  • Loading branch information
robgonnella authored and actions-user committed Jul 14, 2024
1 parent c5cfa02 commit cbbadcd
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 120 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# go-lanscan
![Coverage](https://img.shields.io/badge/Coverage-91.8%25-brightgreen)
![Coverage](https://img.shields.io/badge/Coverage-91.7%25-brightgreen)

A network cli and golang package that allows you to perform arp and syn
scanning on a local area network.
Expand Down
109 changes: 59 additions & 50 deletions pkg/scanner/arpscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func (s *ArpScanner) Results() chan *ScanResult {

// Scan implements the Scan method for ARP scanning
func (s *ArpScanner) Scan() error {
defer s.reset()

fields := map[string]interface{}{
"interface": s.networkInfo.Interface().Name,
"cidr": s.networkInfo.Cidr(),
Expand All @@ -92,7 +94,6 @@ func (s *ArpScanner) Scan() error {
s.scanningMux.Lock()
s.scanning = true

// open a new handle each time so we don't hit buffer overflow error
handle, err := s.cap.OpenLive(
s.networkInfo.Interface().Name,
65536,
Expand All @@ -109,6 +110,9 @@ func (s *ArpScanner) Scan() error {
s.scanningMux.Unlock()
s.handle = handle

timeout := make(chan struct{})

go s.startPacketReceiveTimeout(timeout)
go s.readPackets()

limiter := time.NewTicker(s.timing)
Expand All @@ -131,8 +135,18 @@ func (s *ArpScanner) Scan() error {
}

s.packetSentAtMux.Lock()
defer s.packetSentAtMux.Unlock()
s.lastPacketSentAt = time.Now()
s.packetSentAtMux.Unlock()

<-timeout

go s.Stop()

go func() {
s.resultChan <- &ScanResult{
Type: ARPDone,
}
}()

return err
}
Expand Down Expand Up @@ -183,66 +197,61 @@ func (s *ArpScanner) SetPacketCapture(cap PacketCapture) {
s.cap = cap
}

func (s *ArpScanner) readPackets() {
done := make(chan struct{})
func (s *ArpScanner) startPacketReceiveTimeout(timeout chan<- struct{}) {
for {
select {
case <-s.cancel:
return
default:
s.packetSentAtMux.RLock()
packetSentAt := s.lastPacketSentAt
s.packetSentAtMux.RUnlock()

go func() {
for {
select {
case <-done:
if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout {
go func() {
timeout <- struct{}{}
}()
return
default:
var eth layers.Ethernet
var arp layers.ARP
var payload gopacket.Payload

parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &arp, &payload)
decoded := []gopacket.LayerType{}
packetData, _, err := s.handle.ReadPacketData()

if err != nil {
s.debug.Error().Err(err).Msg("arp: error reading packet")
continue
}

err = parser.DecodeLayers(packetData, &decoded)

if err != nil {
s.debug.Error().Err(err).Msg("arp: error decoding packet")
continue
}

OUTER:
for _, layerType := range decoded {
switch layerType {
case layers.LayerTypeARP:
go s.handleARPLayer(&arp)
break OUTER
}
}
}
}
}()

defer func() {
go close(done)
s.reset()
}()
time.Sleep(time.Second)
}
}
}

func (s *ArpScanner) readPackets() {
for {
select {
case <-s.cancel:
return
default:
s.packetSentAtMux.RLock()
packetSentAt := s.lastPacketSentAt
s.packetSentAtMux.RUnlock()
var eth layers.Ethernet
var arp layers.ARP
var payload gopacket.Payload

if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout {
s.resultChan <- &ScanResult{
Type: ARPDone,
parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &arp, &payload)
decoded := []gopacket.LayerType{}
packetData, _, err := s.handle.ReadPacketData()

if err != nil {
s.debug.Error().Err(err).Msg("arp: error reading packet")
continue
}

err = parser.DecodeLayers(packetData, &decoded)

if err != nil {
s.debug.Error().Err(err).Msg("arp: error decoding packet")
continue
}

INNER:
for _, layerType := range decoded {
switch layerType {
case layers.LayerTypeARP:
go s.handleARPLayer(&arp)
break INNER
}
return
}
}
}
Expand Down
16 changes: 4 additions & 12 deletions pkg/scanner/fullscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package scanner

import (
"bytes"
"fmt"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -116,18 +115,11 @@ func (s *FullScanner) Scan() error {
go s.handleArpDone()
}
case r := <-s.synScanner.Results():
switch r.Type {
case SYNResult:
go func() {
s.results <- r
}()
case SYNDone:
go func() {
s.results <- r
}()
go func() {
s.results <- r
}()
if r.Type == SYNDone {
return nil
default:
return fmt.Errorf("unknown result type: %s", r.Type)
}
case err := <-s.errorChan:
return err
Expand Down
122 changes: 66 additions & 56 deletions pkg/scanner/synscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func (s *SynScanner) Results() chan *ScanResult {

// Scan implements SYN scanning
func (s *SynScanner) Scan() error {
defer s.reset()

fields := map[string]interface{}{
"interface": s.networkInfo.Interface().Name,
"cidr": s.networkInfo.Cidr(),
Expand Down Expand Up @@ -136,6 +138,9 @@ func (s *SynScanner) Scan() error {

s.handle = handle

timeout := make(chan struct{})

go s.startPacketReceiveTimeout(timeout)
go s.readPackets()

limiter := time.NewTicker(s.timing)
Expand All @@ -159,8 +164,18 @@ func (s *SynScanner) Scan() error {
}

s.packetSentAtMux.Lock()
defer s.packetSentAtMux.Unlock()
s.lastPacketSentAt = time.Now()
s.packetSentAtMux.Unlock()

<-timeout

go s.Stop()

go func() {
s.resultChan <- &ScanResult{
Type: SYNDone,
}
}()

return nil
}
Expand Down Expand Up @@ -213,73 +228,68 @@ func (s *SynScanner) SetTargets(targets []*ArpScanResult) {
s.targets = targets
}

func (s *SynScanner) readPackets() {
done := make(chan struct{})
func (s *SynScanner) startPacketReceiveTimeout(timeout chan<- struct{}) {
for {
select {
case <-s.cancel:
return
default:
s.packetSentAtMux.RLock()
packetSentAt := s.lastPacketSentAt
s.packetSentAtMux.RUnlock()

go func() {
for {
select {
case <-done:
if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout {
go func() {
timeout <- struct{}{}
}()
return
default:
var eth layers.Ethernet
var ip4 layers.IPv4
var tcp layers.TCP
var payload gopacket.Payload

parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &ip4, &tcp, &payload)
decoded := []gopacket.LayerType{}
packetData, _, err := s.handle.ReadPacketData()

if err != nil {
s.debug.Error().Err(err).Msg("syn: read packet error")
continue
}

err = parser.DecodeLayers(packetData, &decoded)

if err != nil {
s.debug.Error().Err(err).Msg("syn: decode packet error")
continue
}

synPacket := &SynPacket{}

for _, layerType := range decoded {
switch layerType {
case layers.LayerTypeIPv4:
synPacket.IP4 = &ip4
case layers.LayerTypeTCP:
synPacket.TCP = &tcp
}
}

if synPacket.IP4 != nil && synPacket.TCP != nil {
go s.handlePacket(synPacket)
}
}
}
}()

defer func() {
go close(done)
s.reset()
}()
time.Sleep(time.Second)
}
}
}

func (s *SynScanner) readPackets() {
for {
select {
case <-s.cancel:
return
default:
s.packetSentAtMux.RLock()
packetSentAt := s.lastPacketSentAt
s.packetSentAtMux.RUnlock()
var eth layers.Ethernet
var ip4 layers.IPv4
var tcp layers.TCP
var payload gopacket.Payload

parser := gopacket.NewDecodingLayerParser(layers.LayerTypeEthernet, &eth, &ip4, &tcp, &payload)
decoded := []gopacket.LayerType{}
packetData, _, err := s.handle.ReadPacketData()

if err != nil {
s.debug.Error().Err(err).Msg("syn: read packet error")
continue
}

if !packetSentAt.IsZero() && time.Since(packetSentAt) >= s.idleTimeout {
s.resultChan <- &ScanResult{
Type: SYNDone,
err = parser.DecodeLayers(packetData, &decoded)

if err != nil {
s.debug.Error().Err(err).Msg("syn: decode packet error")
continue
}

synPacket := &SynPacket{}

for _, layerType := range decoded {
switch layerType {
case layers.LayerTypeIPv4:
synPacket.IP4 = &ip4
case layers.LayerTypeTCP:
synPacket.TCP = &tcp
}
return
}

if synPacket.IP4 != nil && synPacket.TCP != nil {
go s.handlePacket(synPacket)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scanner/synscan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func TestSynScanner(t *testing.T) {
assert.ErrorIs(st, mockErr, err)
})

t.Run("performs syn scan ", func(st *testing.T) {
t.Run("performs syn scan", func(st *testing.T) {
capture := mock_scanner.NewMockPacketCapture(ctrl)
handle := mock_scanner.NewMockPacketCaptureHandle(ctrl)
netInfo := mock_network.NewMockNetwork(ctrl)
Expand Down

0 comments on commit cbbadcd

Please sign in to comment.