Skip to content

Commit

Permalink
fix: netlink race condition (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobyxdd authored Feb 6, 2024
1 parent 6871244 commit 843f178
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions io/nfqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
var errNotNFQueuePacket = errors.New("not an NFQueue packet")

type nfqueuePacketIO struct {
n *nfqueue.Nfqueue
local bool
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
n *nfqueue.Nfqueue
local bool
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
iptSet bool // whether iptables rules are set
}

type NFQueuePacketIOConfig struct {
Expand Down Expand Up @@ -74,22 +75,16 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
if err != nil {
return nil, err
}
io := &nfqueuePacketIO{
return &nfqueuePacketIO{
n: n,
local: config.Local,
ipt4: ipt4,
ipt6: ipt6,
}
err = io.setupIpt(config.Local, false)
if err != nil {
_ = n.Close()
return nil, err
}
return io, nil
}, nil
}

func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
return n.n.RegisterWithErrorFunc(ctx,
err := n.n.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int {
if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 {
// Invalid packet, ignore
Expand All @@ -106,6 +101,17 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
func(e error) int {
return okBoolToInt(cb(nil, e))
})
if err != nil {
return err
}
if !n.iptSet {
err = n.setupIpt(n.local, false)
if err != nil {
return err
}
n.iptSet = true
}
return nil
}

func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
Expand Down Expand Up @@ -150,9 +156,13 @@ func (n *nfqueuePacketIO) setupIpt(local, remove bool) error {
}

func (n *nfqueuePacketIO) Close() error {
err := n.setupIpt(n.local, true)
_ = n.n.Close()
return err
if n.iptSet {
err := n.setupIpt(n.local, true)
if err != nil {
return err
}
}
return n.n.Close()
}

var _ Packet = (*nfqueuePacket)(nil)
Expand Down

0 comments on commit 843f178

Please sign in to comment.