Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace log.Logger with Logger interface #50

Merged
merged 1 commit into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package nfqueue
import (
"bytes"
"encoding/binary"
"log"
"time"

"github.com/florianl/go-nfqueue/internal/unix"

"github.com/mdlayher/netlink"
)

func extractAttribute(log *log.Logger, a *Attribute, data []byte) error {
func extractAttribute(log Logger, a *Attribute, data []byte) error {
ad, err := netlink.NewAttributeDecoder(data)
if err != nil {
return err
Expand Down Expand Up @@ -91,7 +90,7 @@ func extractAttribute(log *log.Logger, a *Attribute, data []byte) error {
skbPrio := ad.Uint32()
a.SkbPrio = &skbPrio
default:
log.Printf("Unknown attribute Type: 0x%x\tData: %v\n", ad.Type(), ad.Bytes())
log.Errorf("Unknown attribute Type: 0x%x\tData: %v", ad.Type(), ad.Bytes())
}
}

Expand All @@ -105,7 +104,7 @@ func checkHeader(data []byte) int {
return 0
}

func extractAttributes(log *log.Logger, msg []byte) (Attribute, error) {
func extractAttributes(log Logger, msg []byte) (Attribute, error) {
attrs := Attribute{}

offset := checkHeader(msg[:2])
Expand Down
32 changes: 19 additions & 13 deletions nfqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/binary"
"fmt"
"log"
"sync"
"time"

Expand All @@ -13,12 +12,13 @@ import (
"github.com/mdlayher/netlink"
)

// devNull satisfies io.Writer, in case *log.Logger is not provided
var _ Logger = (*devNull)(nil)

// devNull satisfies the Logger interface.
type devNull struct{}

func (devNull) Write(p []byte) (int, error) {
return 0, nil
}
func (dn *devNull) Debugf(format string, args ...interface{}) {}
func (dn *devNull) Errorf(format string, args ...interface{}) {}

// Close the connection to the netfilter queue subsystem
func (nfqueue *Nfqueue) Close() error {
Expand Down Expand Up @@ -150,7 +150,7 @@ func (nfqueue *Nfqueue) Register(ctx context.Context, fn HookFunc) error {
return 0
}
}
nfqueue.logger.Printf("Could not receive message: %v\n", err)
nfqueue.logger.Errorf("Could not receive message: %v", err)
return 1
})
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func (nfqueue *Nfqueue) execute(req netlink.Message) (uint32, error) {
return seq, nil
}

func parseMsg(log *log.Logger, msg netlink.Message) (Attribute, error) {
func parseMsg(log Logger, msg netlink.Message) (Attribute, error) {
a, err := extractAttributes(log, msg.Data)
if err != nil {
return a, err
Expand All @@ -272,7 +272,7 @@ type Nfqueue struct {
// Con is the pure representation of a netlink socket
Con *netlink.Conn

logger *log.Logger
logger Logger

wg sync.WaitGroup

Expand All @@ -286,6 +286,12 @@ type Nfqueue struct {
setWriteTimeout func() error
}

// Logger provides logging functionality.
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}

// Open a connection to the netfilter queue subsystem
func Open(config *Config) (*Nfqueue, error) {
var nfqueue Nfqueue
Expand All @@ -309,7 +315,7 @@ func Open(config *Config) (*Nfqueue, error) {
nfqueue.maxQueueLen = []byte{0x00, 0x00, 0x00, 0x00}
binary.BigEndian.PutUint32(nfqueue.maxQueueLen, config.MaxQueueLen)
if config.Logger == nil {
nfqueue.logger = log.New(new(devNull), "", 0)
nfqueue.logger = new(devNull)
} else {
nfqueue.logger = config.Logger
}
Expand Down Expand Up @@ -365,7 +371,7 @@ func (nfqueue *Nfqueue) setVerdict(id uint32, verdict int, batch bool, attribute
}

if err := nfqueue.setWriteTimeout(); err != nil {
nfqueue.logger.Printf("could not set write timeout: %v\n", err)
nfqueue.logger.Errorf("could not set write timeout: %v\n", err)
}
_, sErr := nfqueue.Con.Send(req)
return sErr
Expand All @@ -378,7 +384,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E
{Type: nfQaCfgCmd, Data: []byte{nfUlnlCfgCmdUnbind, 0x0, 0x0, byte(nfqueue.family)}},
})
if err != nil {
nfqueue.logger.Printf("Could not unbind from queue: %v\n", err)
nfqueue.logger.Errorf("Could not unbind from queue: %v", err)
}
}()

Expand All @@ -395,7 +401,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E

for {
if err := ctx.Err(); err != nil {
nfqueue.logger.Printf("Stop receiving nfqueue messages: %v\n", err)
nfqueue.logger.Errorf("Stop receiving nfqueue messages: %v", err)
return
}
replys, err := nfqueue.Con.Receive()
Expand All @@ -413,7 +419,7 @@ func (nfqueue *Nfqueue) socketCallback(ctx context.Context, fn HookFunc, errfn E
}
m, err := parseMsg(nfqueue.logger, msg)
if err != nil {
nfqueue.logger.Printf("Could not parse message: %v", err)
nfqueue.logger.Errorf("Could not parse message: %v", err)
continue
}
if ret := fn(m); ret != 0 {
Expand Down
3 changes: 1 addition & 2 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package nfqueue

import (
"errors"
"log"
"time"
)

Expand Down Expand Up @@ -73,7 +72,7 @@ type Config struct {
WriteTimeout time.Duration

// Interface to log internals.
Logger *log.Logger
Logger Logger
}

// Various errors
Expand Down
Loading