Skip to content

Commit

Permalink
replace log.Logger with Logger interface
Browse files Browse the repository at this point in the history
Signed-off-by: Florian Lehner <[email protected]>
  • Loading branch information
florianl committed May 11, 2024
1 parent 02940c3 commit c7c4099
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
7 changes: 3 additions & 4 deletions attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package nfqueue
import (
"bytes"
"encoding/binary"
"log"
"time"

"github.com/florianl/go-nfqueue/internal/unix"

"github.com/mdlayher/netlink"
)

func extractAttribute(log *log.Logger, a *Attribute, data []byte) error {
func extractAttribute(log Logger, a *Attribute, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
Expand Down Expand Up @@ -91,7 +90,7 @@ func extractAttribute(log *log.Logger, a *Attribute, data []byte) error {
skbPrio := ad.Uint32()
a.SkbPrio = &skbPrio
default:
log.Printf("Unknown attribute Type: 0x%x\tData: %v\n", ad.Type(), ad.Bytes())
log.Errorf("Unknown attribute Type: 0x%x\tData: %v", ad.Type(), ad.Bytes())
}
}

Expand All @@ -105,7 +104,7 @@ func checkHeader(data []byte) int {
return 0
}

func extractAttributes(log *log.Logger, msg []byte) (Attribute, error) {
func extractAttributes(log Logger, msg []byte) (Attribute, error) {
attrs := Attribute{}

offset := checkHeader(msg[:2])
Expand Down
32 changes: 19 additions & 13 deletions nfqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/binary"
"fmt"
"log"
"sync"
"time"

Expand All @@ -13,12 +12,13 @@ import (
"github.com/mdlayher/netlink"
)

// devNull satisfies io.Writer, in case *log.Logger is not provided
var _ Logger = (*devNull)(nil)

// devNull satisfies the Logger interface.
type devNull struct{}

func (devNull) Write(p []byte) (int, error) {
return 0, nil
}
func (dn *devNull) Debugf(format string, args ...interface{}) {}
func (dn *devNull) Errorf(format string, args ...interface{}) {}

// Close the connection to the netfilter queue subsystem
func (nfqueue *Nfqueue) Close() error {
Expand Down Expand Up @@ -150,7 +150,7 @@ func (nfqueue *Nfqueue) Register(ctx context.Context, fn HookFunc) error {
return 0
}
}
nfqueue.logger.Printf("Could not receive message: %v\n", err)
nfqueue.logger.Errorf("Could not receive message: %v", err)
return 1
})
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func (nfqueue *Nfqueue) execute(req netlink.Message) (uint32, error) {
return seq, nil
}

func parseMsg(log *log.Logger, msg netlink.Message) (Attribute, error) {
func parseMsg(log Logger, msg netlink.Message) (Attribute, error) {
a, err := extractAttributes(log, msg.Data)
if err != nil {
return a, err
Expand All @@ -272,7 +272,7 @@ type Nfqueue struct {
// Con is the pure representation of a netlink socket
Con *netlink.Conn

logger *log.Logger
logger Logger

wg sync.WaitGroup

Expand All @@ -286,6 +286,12 @@ type Nfqueue struct {
setWriteTimeout func() error
}

// Logger provides logging functionality.
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}

// Open a connection to the netfilter queue subsystem
func Open(config *Config) (*Nfqueue, error) {
var nfqueue Nfqueue
Expand All @@ -309,7 +315,7 @@ func Open(config *Config) (*Nfqueue, error) {
nfqueue.maxQueueLen = []byte{0x00, 0x00, 0x00, 0x00}
binary.BigEndian.PutUint32(nfqueue.maxQueueLen, config.MaxQueueLen)
if config.Logger == nil {
nfqueue.logger = log.New(new(devNull), "", 0)
nfqueue.logger = new(devNull)
} else {
nfqueue.logger = config.Logger
}
Expand Down Expand Up @@ -365,7 +371,7 @@ func (nfqueue *Nfqueue) setVerdict(id uint32, verdict int, batch bool, attribute
}

if err := nfqueue.setWriteTimeout(); err != nil {
nfqueue.logger.Printf("could not set write timeout: %v\n", err)
nfqueue.logger.Errorf("could not set write timeout: %v\n", err)
}
_, sErr := nfqueue.Con.Send(req)
return sErr
Expand All @@ -378,7 +384,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E
{Type: nfQaCfgCmd, Data: []byte{nfUlnlCfgCmdUnbind, 0x0, 0x0, byte(nfqueue.family)}},
})
if err != nil {
nfqueue.logger.Printf("Could not unbind from queue: %v\n", err)
nfqueue.logger.Errorf("Could not unbind from queue: %v", err)
}
}()

Expand All @@ -395,7 +401,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E

for {
if err := ctx.Err(); err != nil {
nfqueue.logger.Printf("Stop receiving nfqueue messages: %v\n", err)
nfqueue.logger.Errorf("Stop receiving nfqueue messages: %v", err)
return
}
replys, err := nfqueue.Con.Receive()
Expand All @@ -413,7 +419,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E
}
m, err := parseMsg(nfqueue.logger, msg)
if err != nil {
nfqueue.logger.Printf("Could not parse message: %v", err)
nfqueue.logger.Errorf("Could not parse message: %v", err)
continue
}
if ret := fn(m); ret != 0 {
Expand Down
3 changes: 1 addition & 2 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package nfqueue

import (
"errors"
"log"
"time"
)

Expand Down Expand Up @@ -73,7 +72,7 @@ type Config struct {
WriteTimeout time.Duration

// Interface to log internals.
Logger *log.Logger
Logger Logger
}

// Various errors
Expand Down

0 comments on commit c7c4099

Please sign in to comment.