From 61fb6d77d353e415b24bfc25f4190b1e23d7aa86 Mon Sep 17 00:00:00 2001 From: Ilya Mashchenko Date: Fri, 29 Nov 2024 12:43:26 +0200 Subject: [PATCH] chore(go.d/pkg/socket): add err to callback return values (#19103) --- .../plugin/go.d/collector/beanstalk/client.go | 40 ++- src/go/plugin/go.d/collector/boinc/client.go | 15 +- .../plugin/go.d/collector/dovecot/client.go | 7 +- .../plugin/go.d/collector/gearman/client.go | 25 +- .../plugin/go.d/collector/hddtemp/client.go | 12 +- .../plugin/go.d/collector/memcached/client.go | 8 +- .../go.d/collector/openvpn/client/client.go | 17 +- .../collector/openvpn/client/client_test.go | 4 +- src/go/plugin/go.d/collector/tor/client.go | 22 +- .../plugin/go.d/collector/unbound/collect.go | 4 +- src/go/plugin/go.d/collector/upsd/client.go | 4 +- src/go/plugin/go.d/collector/uwsgi/client.go | 24 +- .../go.d/collector/zookeeper/fetcher.go | 18 +- .../go.d/collector/zookeeper/fetcher_test.go | 12 +- .../plugin/go.d/collector/zookeeper/init.go | 7 +- src/go/plugin/go.d/pkg/socket/client.go | 118 ++++---- src/go/plugin/go.d/pkg/socket/client_test.go | 218 ++++++--------- src/go/plugin/go.d/pkg/socket/server.go | 257 ++++++++++++++++++ src/go/plugin/go.d/pkg/socket/servers_test.go | 139 ---------- src/go/plugin/go.d/pkg/socket/utils.go | 2 +- 20 files changed, 477 insertions(+), 476 deletions(-) create mode 100644 src/go/plugin/go.d/pkg/socket/server.go delete mode 100644 src/go/plugin/go.d/pkg/socket/servers_test.go diff --git a/src/go/plugin/go.d/collector/beanstalk/client.go b/src/go/plugin/go.d/collector/beanstalk/client.go index 00fd8a13c347c7..c3839e6e9d0819 100644 --- a/src/go/plugin/go.d/collector/beanstalk/client.go +++ b/src/go/plugin/go.d/collector/beanstalk/client.go @@ -88,9 +88,10 @@ func newBeanstalkConn(conf Config, log *logger.Logger) beanstalkConn { return &beanstalkClient{ Logger: log, client: socket.New(socket.Config{ - Address: conf.Address, - Timeout: conf.Timeout.Duration(), - TLSConf: nil, + Address: conf.Address, + Timeout: conf.Timeout.Duration(), + MaxReadLines: 2000, + TLSConf: nil, }), } } @@ -180,43 +181,34 @@ func (c *beanstalkClient) queryStatsTube(tubeName string) (*tubeStats, error) { } func (c *beanstalkClient) query(command string) (string, []byte, error) { - var resp string - var length int - var body []byte - var err error - c.Debugf("executing command: %s", command) - const limitReadLines = 1000 - var num int + var ( + resp string + body []byte + length int + err error + ) - clientErr := c.client.Command(command+"\r\n", func(line []byte) bool { + if err := c.client.Command(command+"\r\n", func(line []byte) (bool, error) { if resp == "" { s := string(line) c.Debugf("command '%s' response: '%s'", command, s) resp, length, err = parseResponseLine(s) if err != nil { - err = fmt.Errorf("command '%s' line '%s': %v", command, s, err) + return false, fmt.Errorf("command '%s' line '%s': %v", command, s, err) } - return err == nil && resp == "OK" - } - if num++; num >= limitReadLines { - err = fmt.Errorf("command '%s': read line limit exceeded (%d)", command, limitReadLines) - return false + return resp == "OK", nil } body = append(body, line...) body = append(body, '\n') - return len(body) < length - }) - if clientErr != nil { - return "", nil, fmt.Errorf("command '%s' client error: %v", command, clientErr) - } - if err != nil { - return "", nil, err + return len(body) < length, nil + }); err != nil { + return "", nil, fmt.Errorf("command '%s': %v", command, err) } return resp, body, nil diff --git a/src/go/plugin/go.d/collector/boinc/client.go b/src/go/plugin/go.d/collector/boinc/client.go index 7635330a25f1b8..15cd4e134add07 100644 --- a/src/go/plugin/go.d/collector/boinc/client.go +++ b/src/go/plugin/go.d/collector/boinc/client.go @@ -111,25 +111,20 @@ func (c *boincClient) send(req *boincRequest) (*boincReply, error) { var b bytes.Buffer - clientErr := c.conn.Command(string(reqData), func(bs []byte) bool { + if err := c.conn.Command(string(reqData), func(bs []byte) (bool, error) { s := strings.TrimSpace(string(bs)) if s == "" { - return true + return true, nil } if b.Len() == 0 && s != respStart { - err = fmt.Errorf("unexpected response first line: %s", s) - return false + return false, fmt.Errorf("unexpected response first line: %s", s) } b.WriteString(s) - return s != respEnd - }) - if clientErr != nil { - return nil, fmt.Errorf("failed to send command: %v", clientErr) - } - if err != nil { + return s != respEnd, nil + }); err != nil { return nil, fmt.Errorf("failed to send command: %v", err) } diff --git a/src/go/plugin/go.d/collector/dovecot/client.go b/src/go/plugin/go.d/collector/dovecot/client.go index b7b5fa2c6fe815..962cb693d0bde5 100644 --- a/src/go/plugin/go.d/collector/dovecot/client.go +++ b/src/go/plugin/go.d/collector/dovecot/client.go @@ -37,14 +37,13 @@ func (c *dovecotClient) queryExportGlobal() ([]byte, error) { var b bytes.Buffer var n int - err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) bool { + if err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) (bool, error) { b.Write(bs) b.WriteByte('\n') n++ - return n < 2 - }) - if err != nil { + return n < 2, nil + }); err != nil { return nil, err } diff --git a/src/go/plugin/go.d/collector/gearman/client.go b/src/go/plugin/go.d/collector/gearman/client.go index c42b2b9bd17f03..9006d6eb50884f 100644 --- a/src/go/plugin/go.d/collector/gearman/client.go +++ b/src/go/plugin/go.d/collector/gearman/client.go @@ -19,8 +19,9 @@ type gearmanConn interface { func newGearmanConn(conf Config) gearmanConn { return &gearmanClient{conn: socket.New(socket.Config{ - Address: conf.Address, - Timeout: conf.Timeout.Duration(), + Address: conf.Address, + Timeout: conf.Timeout.Duration(), + MaxReadLines: 10000, })} } @@ -45,32 +46,20 @@ func (c *gearmanClient) queryPriorityStatus() ([]byte, error) { } func (c *gearmanClient) query(cmd string) ([]byte, error) { - const limitReadLines = 10000 - var num int - var err error var b bytes.Buffer - clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool { + if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) { s := string(bs) if strings.HasPrefix(s, "ERR") { - err = fmt.Errorf("command '%s': %s", cmd, s) - return false + return false, fmt.Errorf("command '%s': %s", cmd, s) } b.WriteString(s) b.WriteByte('\n') - if num++; num >= limitReadLines { - err = fmt.Errorf("command '%s': read line limit exceeded (%d)", cmd, limitReadLines) - return false - } - return !strings.HasPrefix(s, ".") - }) - if clientErr != nil { - return nil, fmt.Errorf("command '%s' client error: %v", cmd, clientErr) - } - if err != nil { + return !strings.HasPrefix(s, "."), nil + }); err != nil { return nil, err } diff --git a/src/go/plugin/go.d/collector/hddtemp/client.go b/src/go/plugin/go.d/collector/hddtemp/client.go index d289e3a8ab8f48..aa02411396a9c9 100644 --- a/src/go/plugin/go.d/collector/hddtemp/client.go +++ b/src/go/plugin/go.d/collector/hddtemp/client.go @@ -25,21 +25,15 @@ type hddtempClient struct { } func (c *hddtempClient) queryHddTemp() (string, error) { - var i int - var s string - cfg := socket.Config{ Address: c.address, Timeout: c.timeout, } - err := socket.ConnectAndRead(cfg, func(bs []byte) bool { - if i++; i > 1 { - return false - } + var s string + err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) { s = string(bs) - return true - + return false, nil }) if err != nil { return "", err diff --git a/src/go/plugin/go.d/collector/memcached/client.go b/src/go/plugin/go.d/collector/memcached/client.go index aa3a45294ffddb..e561ed06af2961 100644 --- a/src/go/plugin/go.d/collector/memcached/client.go +++ b/src/go/plugin/go.d/collector/memcached/client.go @@ -36,13 +36,13 @@ func (c *memcachedClient) disconnect() { func (c *memcachedClient) queryStats() ([]byte, error) { var b bytes.Buffer - err := c.conn.Command("stats\r\n", func(bytes []byte) bool { + if err := c.conn.Command("stats\r\n", func(bytes []byte) (bool, error) { s := strings.TrimSpace(string(bytes)) b.WriteString(s) b.WriteByte('\n') - return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR")) - }) - if err != nil { + + return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR")), nil + }); err != nil { return nil, err } return b.Bytes(), nil diff --git a/src/go/plugin/go.d/collector/openvpn/client/client.go b/src/go/plugin/go.d/collector/openvpn/client/client.go index 23ceb18d87a2ec..1f1288e368e753 100644 --- a/src/go/plugin/go.d/collector/openvpn/client/client.go +++ b/src/go/plugin/go.d/collector/openvpn/client/client.go @@ -57,29 +57,26 @@ func (c *Client) Version() (*Version, error) { func (c *Client) get(command string, stopRead stopReadFunc) (output []string, err error) { var num int - var maxLinesErr error - err = c.Command(command, func(bytes []byte) bool { + if err := c.Command(command, func(bytes []byte) (bool, error) { line := string(bytes) num++ if num > maxLinesToRead { - maxLinesErr = fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead) - return false + return false, fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead) } // skip real-time messages if strings.HasPrefix(line, ">") { - return true + return true, nil } line = strings.Trim(line, "\r\n ") output = append(output, line) if stopRead != nil && stopRead(line) { - return false + return false, nil } - return true - }) - if maxLinesErr != nil { - return nil, maxLinesErr + return true, nil + }); err != nil { + return nil, err } return output, err } diff --git a/src/go/plugin/go.d/collector/openvpn/client/client_test.go b/src/go/plugin/go.d/collector/openvpn/client/client_test.go index d1257e877b3092..b7fa37b3c24548 100644 --- a/src/go/plugin/go.d/collector/openvpn/client/client_test.go +++ b/src/go/plugin/go.d/collector/openvpn/client/client_test.go @@ -98,7 +98,9 @@ func (m *mockSocketClient) Command(command string, process socket.Processor) err } for s.Scan() { - process(s.Bytes()) + if _, err := process(s.Bytes()); err != nil { + return err + } } return nil } diff --git a/src/go/plugin/go.d/collector/tor/client.go b/src/go/plugin/go.d/collector/tor/client.go index 66e784c3f1530f..52420cb1962e51 100644 --- a/src/go/plugin/go.d/collector/tor/client.go +++ b/src/go/plugin/go.d/collector/tor/client.go @@ -58,9 +58,9 @@ func (c *torControlClient) authenticate() error { } var s string - err := c.conn.Command(cmd+"\n", func(bs []byte) bool { + err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) { s = string(bs) - return false + return false, nil }) if err != nil { return fmt.Errorf("authentication failed: %v", err) @@ -74,7 +74,7 @@ func (c *torControlClient) authenticate() error { func (c *torControlClient) disconnect() { // https://spec.torproject.org/control-spec/commands.html#quit - _ = c.conn.Command(cmdQuit+"\n", func(bs []byte) bool { return false }) + _ = c.conn.Command(cmdQuit+"\n", func(bs []byte) (bool, error) { return false, nil }) _ = c.conn.Disconnect() } @@ -87,27 +87,21 @@ func (c *torControlClient) getInfo(keywords ...string) ([]byte, error) { cmd := fmt.Sprintf("%s %s", cmdGetInfo, strings.Join(keywords, " ")) var buf bytes.Buffer - var err error - clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool { + if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) { s := string(bs) switch { case strings.HasPrefix(s, "250-"): buf.WriteString(strings.TrimPrefix(s, "250-")) buf.WriteByte('\n') - return true + return true, nil case strings.HasPrefix(s, "250 "): - return false + return false, nil default: - err = errors.New(s) - return false + return false, errors.New(s) } - }) - if clientErr != nil { - return nil, fmt.Errorf("command '%s' failed: %v", cmd, clientErr) - } - if err != nil { + }); err != nil { return nil, fmt.Errorf("command '%s' failed: %v", cmd, err) } diff --git a/src/go/plugin/go.d/collector/unbound/collect.go b/src/go/plugin/go.d/collector/unbound/collect.go index 411ce828412a84..c44e9f1549cbef 100644 --- a/src/go/plugin/go.d/collector/unbound/collect.go +++ b/src/go/plugin/go.d/collector/unbound/collect.go @@ -36,9 +36,9 @@ func (c *Collector) scrapeUnboundStats() ([]entry, error) { } defer func() { _ = c.client.Disconnect() }() - err := c.client.Command(command+"\n", func(bytes []byte) bool { + err := c.client.Command(command+"\n", func(bytes []byte) (bool, error) { output = append(output, string(bytes)) - return true + return true, nil }) if err != nil { return nil, fmt.Errorf("send command '%s': %w", command, err) diff --git a/src/go/plugin/go.d/collector/upsd/client.go b/src/go/plugin/go.d/collector/upsd/client.go index a88723e07583d0..67b4a26695b5d6 100644 --- a/src/go/plugin/go.d/collector/upsd/client.go +++ b/src/go/plugin/go.d/collector/upsd/client.go @@ -133,7 +133,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) { var errMsg string endLine := getEndLine(cmd) - err := c.conn.Command(cmd+"\n", func(bytes []byte) bool { + err := c.conn.Command(cmd+"\n", func(bytes []byte) (bool, error) { line := string(bytes) resp = append(resp, line) @@ -141,7 +141,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) { errMsg = strings.TrimPrefix(line, "ERR ") } - return line != endLine && errMsg == "" + return line != endLine && errMsg == "", nil }) if err != nil { return nil, err diff --git a/src/go/plugin/go.d/collector/uwsgi/client.go b/src/go/plugin/go.d/collector/uwsgi/client.go index 12138672f4eb12..80908b8a5cd00b 100644 --- a/src/go/plugin/go.d/collector/uwsgi/client.go +++ b/src/go/plugin/go.d/collector/uwsgi/client.go @@ -4,7 +4,6 @@ package uwsgi import ( "bytes" - "fmt" "time" "github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket" @@ -28,30 +27,19 @@ type uwsgiClient struct { func (c *uwsgiClient) queryStats() ([]byte, error) { var b bytes.Buffer - var n int64 - var err error - const readLineLimit = 1000 * 10 cfg := socket.Config{ - Address: c.address, - Timeout: c.timeout, + Address: c.address, + Timeout: c.timeout, + MaxReadLines: 1000 * 10, } - clientErr := socket.ConnectAndRead(cfg, func(bs []byte) bool { + if err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) { b.Write(bs) b.WriteByte('\n') - - if n++; n >= readLineLimit { - err = fmt.Errorf("read line limit exceeded %d", readLineLimit) - return false - } // The server will close the connection when it has finished sending data. - return true - }) - if clientErr != nil { - return nil, clientErr - } - if err != nil { + return true, nil + }); err != nil { return nil, err } diff --git a/src/go/plugin/go.d/collector/zookeeper/fetcher.go b/src/go/plugin/go.d/collector/zookeeper/fetcher.go index a6ae2c052a8037..e390ffecedf301 100644 --- a/src/go/plugin/go.d/collector/zookeeper/fetcher.go +++ b/src/go/plugin/go.d/collector/zookeeper/fetcher.go @@ -4,14 +4,11 @@ package zookeeper import ( "bytes" - "fmt" "unsafe" "github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket" ) -const limitReadLines = 2000 - type fetcher interface { fetch(command string) ([]string, error) } @@ -26,21 +23,12 @@ func (c *zookeeperFetcher) fetch(command string) (rows []string, err error) { } defer func() { _ = c.Disconnect() }() - var num int - clientErr := c.Command(command, func(b []byte) bool { + if err := c.Command(command, func(b []byte) (bool, error) { if !isZKLine(b) || isMntrLineOK(b) { rows = append(rows, string(b)) } - if num += 1; num >= limitReadLines { - err = fmt.Errorf("read line limit exceeded (%d)", limitReadLines) - return false - } - return true - }) - if clientErr != nil { - return nil, clientErr - } - if err != nil { + return true, nil + }); err != nil { return nil, err } diff --git a/src/go/plugin/go.d/collector/zookeeper/fetcher_test.go b/src/go/plugin/go.d/collector/zookeeper/fetcher_test.go index 582d91c6f8ac41..846574ea09338f 100644 --- a/src/go/plugin/go.d/collector/zookeeper/fetcher_test.go +++ b/src/go/plugin/go.d/collector/zookeeper/fetcher_test.go @@ -22,14 +22,6 @@ func Test_clientFetch(t *testing.T) { assert.Len(t, rows, 10) } -func Test_clientFetchReadLineLimitExceeded(t *testing.T) { - c := &zookeeperFetcher{Client: &mockSocket{rowsNumResp: limitReadLines + 1}} - - rows, err := c.fetch("whatever\n") - assert.Error(t, err) - assert.Len(t, rows, 0) -} - type mockSocket struct { rowsNumResp int } @@ -44,7 +36,9 @@ func (m *mockSocket) Disconnect() error { func (m *mockSocket) Command(command string, process socket.Processor) error { for i := 0; i < m.rowsNumResp; i++ { - process([]byte(command)) + if _, err := process([]byte(command)); err != nil { + return err + } } return nil } diff --git a/src/go/plugin/go.d/collector/zookeeper/init.go b/src/go/plugin/go.d/collector/zookeeper/init.go index 5f80f775a0902f..4abd6e6ff53efd 100644 --- a/src/go/plugin/go.d/collector/zookeeper/init.go +++ b/src/go/plugin/go.d/collector/zookeeper/init.go @@ -30,9 +30,10 @@ func (c *Collector) initZookeeperFetcher() (fetcher, error) { } sock := socket.New(socket.Config{ - Address: c.Address, - Timeout: c.Timeout.Duration(), - TLSConf: tlsConf, + Address: c.Address, + Timeout: c.Timeout.Duration(), + TLSConf: tlsConf, + MaxReadLines: 2000, }) return &zookeeperFetcher{Client: sock}, nil diff --git a/src/go/plugin/go.d/pkg/socket/client.go b/src/go/plugin/go.d/pkg/socket/client.go index f7f5a4bd5d6f7f..4da4294ae7213b 100644 --- a/src/go/plugin/go.d/pkg/socket/client.go +++ b/src/go/plugin/go.d/pkg/socket/client.go @@ -4,25 +4,29 @@ package socket import ( "bufio" + "context" "crypto/tls" "errors" + "fmt" "net" "time" ) -// Processor function passed to the Socket.Command function. -// It is passed by the caller to process a command's response line by line. -type Processor func([]byte) bool +// Processor is a callback function passed to the Socket.Command method. +// It processes each response line received from the server. +type Processor func([]byte) (bool, error) -// Client is the interface that wraps the basic socket client operations -// and hides the implementation details from the users. -// Implementations should return TCP, UDP or Unix ready sockets. +// Client defines an interface for socket clients, abstracting the underlying implementation. +// Implementations should provide connections for various socket types such as TCP, UDP, or Unix domain sockets. type Client interface { Connect() error Disconnect() error Command(command string, process Processor) error } +// ConnectAndRead establishes a connection using the given configuration, +// executes the provided processor function on the incoming response lines, +// and ensures the connection is properly closed after use. func ConnectAndRead(cfg Config, process Processor) error { sock := New(cfg) @@ -35,46 +39,33 @@ func ConnectAndRead(cfg Config, process Processor) error { return sock.read(process) } -// New returns a new pointer to a socket client given the socket -// type (IP, TCP, UDP, UNIX), a network address (IP/domain:port), -// a timeout and a TLS config. It supports both IPv4 and IPv6 address -// and reuses connection where possible. +// New creates and returns a new Socket instance configured with the provided settings. +// The socket supports multiple types (TCP, UDP, UNIX), addresses (IPv4, IPv6, domain names), +// and optional TLS encryption. Connections are reused where possible. func New(cfg Config) *Socket { return &Socket{Config: cfg} } -// Socket is the implementation of a socket client. +// Socket is a concrete implementation of the Client interface, managing a network connection +// based on the specified configuration (address, type, timeout, and optional TLS settings). type Socket struct { Config conn net.Conn } -// Config holds the network ip v4 or v6 address, port, -// Socket type(ip, tcp, udp, unix), timeout and TLS configuration for a Socket +// Config encapsulates the settings required to establish a network connection. type Config struct { - Address string - Timeout time.Duration - TLSConf *tls.Config + Address string + Timeout time.Duration + TLSConf *tls.Config + MaxReadLines int64 } -// Connect connects to the Socket address on the named network. -// If the address is a domain name it will also perform the DNS resolution. -// Address like :80 will attempt to connect to the localhost. -// The config timeout and TLS config will be used. +// Connect establishes a connection to the specified address using the configuration details. func (s *Socket) Connect() error { - network, address := networkType(s.Address) - var conn net.Conn - var err error - - if s.TLSConf == nil { - conn, err = net.DialTimeout(network, address, s.timeout()) - } else { - var d net.Dialer - d.Timeout = s.timeout() - conn, err = tls.DialWithDialer(&d, network, address, s.TLSConf) - } + conn, err := s.dial() if err != nil { - return err + return fmt.Errorf("socket.Connect: %w", err) } s.conn = conn @@ -82,22 +73,19 @@ func (s *Socket) Connect() error { return nil } -// Disconnect closes the connection. -// Any in-flight commands will be cancelled and return errors. -func (s *Socket) Disconnect() (err error) { - if s.conn != nil { - err = s.conn.Close() - s.conn = nil +// Disconnect terminates the active connection if one exists. +func (s *Socket) Disconnect() error { + if s.conn == nil { + return nil } + err := s.conn.Close() + s.conn = nil return err } -// Command writes the command string to the connection and passed the -// response bytes line by line to the process function. It uses the -// timeout value from the Socket config and returns read, write and -// timeout errors if any. If a timeout occurs during the processing -// of the responses this function will stop processing and return a -// timeout error. +// Command sends a command string to the connected server and processes its response line by line +// using the provided Processor function. This method respects the timeout configuration +// for write and read operations. If a timeout or processing error occurs, it stops and returns the error. func (s *Socket) Command(command string, process Processor) error { if s.conn == nil { return errors.New("cannot send command on nil connection") @@ -112,10 +100,10 @@ func (s *Socket) Command(command string, process Processor) error { func (s *Socket) write(command string) error { if s.conn == nil { - return errors.New("attempt to write on nil connection") + return errors.New("write: nil connection") } - if err := s.conn.SetWriteDeadline(time.Now().Add(s.timeout())); err != nil { + if err := s.conn.SetWriteDeadline(s.deadline()); err != nil { return err } @@ -126,25 +114,53 @@ func (s *Socket) write(command string) error { func (s *Socket) read(process Processor) error { if process == nil { - return errors.New("process func is nil") + return errors.New("read: process func is nil") } - if s.conn == nil { - return errors.New("attempt to read on nil connection") + return errors.New("read: nil connection") } - if err := s.conn.SetReadDeadline(time.Now().Add(s.timeout())); err != nil { + if err := s.conn.SetReadDeadline(s.deadline()); err != nil { return err } sc := bufio.NewScanner(s.conn) - for sc.Scan() && process(sc.Bytes()) { + var n int64 + limit := s.MaxReadLines + + for sc.Scan() { + more, err := process(sc.Bytes()) + if err != nil { + return err + } + if n++; limit > 0 && n > limit { + return fmt.Errorf("read line limit exceeded (%d", limit) + } + if !more { + break + } } return sc.Err() } +func (s *Socket) dial() (net.Conn, error) { + network, address := parseAddress(s.Address) + + var d net.Dialer + d.Timeout = s.timeout() + + if s.TLSConf != nil { + return tls.DialWithDialer(&d, network, address, s.TLSConf) + } + return d.DialContext(context.Background(), network, address) +} + +func (s *Socket) deadline() time.Time { + return time.Now().Add(s.timeout()) +} + func (s *Socket) timeout() time.Duration { if s.Timeout == 0 { return time.Second diff --git a/src/go/plugin/go.d/pkg/socket/client_test.go b/src/go/plugin/go.d/pkg/socket/client_test.go index 53de50951ad5d0..8ef3f9b9957677 100644 --- a/src/go/plugin/go.d/pkg/socket/client_test.go +++ b/src/go/plugin/go.d/pkg/socket/client_test.go @@ -3,152 +3,86 @@ package socket import ( - "crypto/tls" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -const ( - testServerAddress = "127.0.0.1:9999" - testUdpServerAddress = "udp://127.0.0.1:9999" - testUnixServerAddress = "/tmp/testSocketFD" - defaultTimeout = 100 * time.Millisecond -) - -var tcpConfig = Config{ - Address: testServerAddress, - Timeout: defaultTimeout, - TLSConf: nil, -} - -var udpConfig = Config{ - Address: testUdpServerAddress, - Timeout: defaultTimeout, - TLSConf: nil, -} - -var unixConfig = Config{ - Address: testUnixServerAddress, - Timeout: defaultTimeout, - TLSConf: nil, -} - -var tcpTlsConfig = Config{ - Address: testServerAddress, - Timeout: defaultTimeout, - TLSConf: &tls.Config{}, -} - -func Test_clientCommand(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }() - - time.Sleep(time.Millisecond * 100) - sock := New(tcpConfig) - require.NoError(t, sock.Connect()) - err := sock.Command("ping\n", func(bytes []byte) bool { - assert.Equal(t, "pong", string(bytes)) - return true - }) - require.NoError(t, sock.Disconnect()) - require.NoError(t, err) -} - -func Test_clientTimeout(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run() }() - - time.Sleep(time.Millisecond * 100) - sock := New(tcpConfig) - require.NoError(t, sock.Connect()) - sock.Timeout = 0 - err := sock.Command("ping\n", func(bytes []byte) bool { - assert.Equal(t, "pong", string(bytes)) - return true - }) - require.NoError(t, err) -} - -func Test_clientIncompleteSSL(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run() }() - - time.Sleep(time.Millisecond * 100) - sock := New(tcpTlsConfig) - err := sock.Connect() - require.Error(t, err) -} - -func Test_clientCommandStopProcessing(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 2} - go func() { _ = srv.Run() }() - - time.Sleep(time.Millisecond * 100) - sock := New(tcpConfig) - require.NoError(t, sock.Connect()) - err := sock.Command("ping\n", func(bytes []byte) bool { - assert.Equal(t, "pong", string(bytes)) - return false - }) - require.NoError(t, sock.Disconnect()) - require.NoError(t, err) -} - -func Test_clientUDPCommand(t *testing.T) { - srv := &udpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }() - - time.Sleep(time.Millisecond * 100) - sock := New(udpConfig) - require.NoError(t, sock.Connect()) - err := sock.Command("ping\n", func(bytes []byte) bool { - assert.Equal(t, "pong", string(bytes)) - return false - }) - require.NoError(t, sock.Disconnect()) - require.NoError(t, err) -} - -func Test_clientTCPAddress(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run() }() - time.Sleep(time.Millisecond * 100) - - sock := New(tcpConfig) - require.NoError(t, sock.Connect()) - - tcpConfig.Address = "tcp://" + tcpConfig.Address - sock = New(tcpConfig) - require.NoError(t, sock.Connect()) -} - -func Test_clientUnixCommand(t *testing.T) { - srv := &unixServer{addr: testUnixServerAddress, rowsNumResp: 1} - // cleanup previous file descriptors - _ = srv.Close() - go func() { _ = srv.Run() }() - - time.Sleep(time.Millisecond * 200) - sock := New(unixConfig) - require.NoError(t, sock.Connect()) - err := sock.Command("ping\n", func(bytes []byte) bool { - assert.Equal(t, "pong", string(bytes)) - return false - }) - require.NoError(t, err) - require.NoError(t, sock.Disconnect()) -} - -func Test_clientEmptyProcessFunc(t *testing.T) { - srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} - go func() { _ = srv.Run() }() - - time.Sleep(time.Millisecond * 100) - sock := New(tcpConfig) - require.NoError(t, sock.Connect()) - err := sock.Command("ping\n", nil) - require.Error(t, err, "nil process func should return an error") +func TestSocket_Command(t *testing.T) { + const ( + testServerAddress = "tcp://127.0.0.1:9999" + testUdpServerAddress = "udp://127.0.0.1:9999" + testUnixServerAddress = "unix:///tmp/testSocketFD" + defaultTimeout = 1000 * time.Millisecond + ) + + type server interface { + Run() error + Close() error + } + + tests := map[string]struct { + srv server + cfg Config + wantConnectErr bool + wantCommandErr bool + }{ + "tcp": { + srv: newTCPServer(testServerAddress), + cfg: Config{ + Address: testServerAddress, + Timeout: defaultTimeout, + }, + }, + "udp": { + srv: newUDPServer(testUdpServerAddress), + cfg: Config{ + Address: testUdpServerAddress, + Timeout: defaultTimeout, + }, + }, + "unix": { + srv: newUnixServer(testUnixServerAddress), + cfg: Config{ + Address: testUnixServerAddress, + Timeout: defaultTimeout, + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + go func() { + defer func() { _ = test.srv.Close() }() + require.NoError(t, test.srv.Run()) + }() + time.Sleep(time.Millisecond * 500) + + sock := New(test.cfg) + + err := sock.Connect() + + if test.wantConnectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + defer sock.Disconnect() + + var resp string + err = sock.Command("ping\n", func(bytes []byte) (bool, error) { + resp = string(bytes) + return false, nil + }) + + if test.wantCommandErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, "pong", resp) + } + }) + } } diff --git a/src/go/plugin/go.d/pkg/socket/server.go b/src/go/plugin/go.d/pkg/socket/server.go new file mode 100644 index 00000000000000..ac2fde3b0ad169 --- /dev/null +++ b/src/go/plugin/go.d/pkg/socket/server.go @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package socket + +import ( + "bufio" + "context" + "errors" + "fmt" + "net" + "os" + "sync" + "time" +) + +func newTCPServer(addr string) *tcpServer { + ctx, cancel := context.WithCancel(context.Background()) + _, addr = parseAddress(addr) + return &tcpServer{ + addr: addr, + ctx: ctx, + cancel: cancel, + } +} + +type tcpServer struct { + addr string + listener net.Listener + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc +} + +func (t *tcpServer) Run() error { + var err error + t.listener, err = net.Listen("tcp", t.addr) + if err != nil { + return fmt.Errorf("failed to start TCP server: %w", err) + } + return t.handleConnections() +} + +func (t *tcpServer) Close() (err error) { + t.cancel() + if t.listener != nil { + if err := t.listener.Close(); err != nil { + return fmt.Errorf("failed to close TCP server: %w", err) + } + } + t.wg.Wait() + return nil +} + +func (t *tcpServer) handleConnections() (err error) { + for { + select { + case <-t.ctx.Done(): + return nil + default: + conn, err := t.listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return fmt.Errorf("could not accept connection: %v", err) + } + t.wg.Add(1) + go func() { + defer t.wg.Done() + t.handleConnection(conn) + }() + } + } +} + +func (t *tcpServer) handleConnection(conn net.Conn) { + defer func() { _ = conn.Close() }() + + if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil { + return + } + + rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + if _, err := rw.ReadString('\n'); err != nil { + writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err)) + } else { + writeResponse(rw, "pong\n") + } +} + +func newUDPServer(addr string) *udpServer { + ctx, cancel := context.WithCancel(context.Background()) + _, addr = parseAddress(addr) + return &udpServer{ + addr: addr, + ctx: ctx, + cancel: cancel, + } +} + +type udpServer struct { + addr string + conn *net.UDPConn + ctx context.Context + cancel context.CancelFunc +} + +func (u *udpServer) Run() error { + addr, err := net.ResolveUDPAddr("udp", u.addr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %w", err) + } + + u.conn, err = net.ListenUDP("udp", addr) + if err != nil { + return fmt.Errorf("failed to start UDP server: %w", err) + } + + return u.handleConnections() +} + +func (u *udpServer) Close() (err error) { + u.cancel() + if u.conn != nil { + if err := u.conn.Close(); err != nil { + return fmt.Errorf("failed to close UDP server: %w", err) + } + } + return nil +} + +func (u *udpServer) handleConnections() error { + buffer := make([]byte, 8192) + for { + select { + case <-u.ctx.Done(): + return nil + default: + if err := u.conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + continue + } + + _, addr, err := u.conn.ReadFromUDP(buffer[0:]) + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + return fmt.Errorf("failed to read UDP packet: %w", err) + } + continue + } + + if _, err := u.conn.WriteToUDP([]byte("pong\n"), addr); err != nil { + return fmt.Errorf("failed to write UDP response: %w", err) + } + } + } +} + +func newUnixServer(addr string) *unixServer { + ctx, cancel := context.WithCancel(context.Background()) + _, addr = parseAddress(addr) + return &unixServer{ + addr: addr, + ctx: ctx, + cancel: cancel, + } +} + +type unixServer struct { + addr string + listener *net.UnixListener + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc +} + +func (u *unixServer) Run() error { + if err := os.Remove(u.addr); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to clean up existing socket: %w", err) + } + + addr, err := net.ResolveUnixAddr("unix", u.addr) + if err != nil { + return fmt.Errorf("failed to resolve Unix address: %w", err) + } + + u.listener, err = net.ListenUnix("unix", addr) + if err != nil { + return fmt.Errorf("failed to start Unix server: %w", err) + } + + return u.handleConnections() +} + +func (u *unixServer) Close() error { + u.cancel() + + if u.listener != nil { + if err := u.listener.Close(); err != nil { + return fmt.Errorf("failed to close Unix server: %w", err) + } + } + + u.wg.Wait() + _ = os.Remove(u.addr) + + return nil +} + +func (u *unixServer) handleConnections() error { + for { + select { + case <-u.ctx.Done(): + return nil + default: + if err := u.listener.SetDeadline(time.Now().Add(time.Second)); err != nil { + continue + } + + conn, err := u.listener.AcceptUnix() + if err != nil { + if !errors.Is(err, os.ErrDeadlineExceeded) { + return err + } + continue + } + + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.handleConnection(conn) + }() + } + } +} + +func (u *unixServer) handleConnection(conn net.Conn) { + defer func() { _ = conn.Close() }() + + if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil { + return + } + + rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + if _, err := rw.ReadString('\n'); err != nil { + writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err)) + } else { + writeResponse(rw, "pong\n") + } + +} + +func writeResponse(rw *bufio.ReadWriter, response string) { + _, _ = rw.WriteString(response) + _ = rw.Flush() +} diff --git a/src/go/plugin/go.d/pkg/socket/servers_test.go b/src/go/plugin/go.d/pkg/socket/servers_test.go deleted file mode 100644 index d6617816242d5d..00000000000000 --- a/src/go/plugin/go.d/pkg/socket/servers_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// SPDX-License-Identifier: GPL-3.0-or-later - -package socket - -import ( - "bufio" - "errors" - "fmt" - "net" - "os" - "strings" - "time" -) - -type tcpServer struct { - addr string - server net.Listener - rowsNumResp int -} - -func (t *tcpServer) Run() (err error) { - t.server, err = net.Listen("tcp", t.addr) - if err != nil { - return - } - return t.handleConnections() -} - -func (t *tcpServer) Close() (err error) { - return t.server.Close() -} - -func (t *tcpServer) handleConnections() (err error) { - for { - conn, err := t.server.Accept() - if err != nil || conn == nil { - return errors.New("could not accept connection") - } - t.handleConnection(conn) - } -} - -func (t *tcpServer) handleConnection(conn net.Conn) { - defer func() { _ = conn.Close() }() - _ = conn.SetDeadline(time.Now().Add(time.Millisecond * 100)) - - rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - _, err := rw.ReadString('\n') - if err != nil { - _, _ = rw.WriteString("failed to read input") - _ = rw.Flush() - } else { - resp := strings.Repeat("pong\n", t.rowsNumResp) - _, _ = rw.WriteString(resp) - _ = rw.Flush() - } -} - -type udpServer struct { - addr string - conn *net.UDPConn - rowsNumResp int -} - -func (u *udpServer) Run() (err error) { - addr, err := net.ResolveUDPAddr("udp", u.addr) - if err != nil { - return err - } - u.conn, err = net.ListenUDP("udp", addr) - if err != nil { - return - } - u.handleConnections() - return nil -} - -func (u *udpServer) Close() (err error) { - return u.conn.Close() -} - -func (u *udpServer) handleConnections() { - for { - var buf [2048]byte - _, addr, _ := u.conn.ReadFromUDP(buf[0:]) - resp := strings.Repeat("pong\n", u.rowsNumResp) - _, _ = u.conn.WriteToUDP([]byte(resp), addr) - } -} - -type unixServer struct { - addr string - conn *net.UnixListener - rowsNumResp int -} - -func (u *unixServer) Run() (err error) { - _, _ = os.CreateTemp("/tmp", "testSocketFD") - addr, err := net.ResolveUnixAddr("unix", u.addr) - if err != nil { - return err - } - u.conn, err = net.ListenUnix("unix", addr) - if err != nil { - return - } - go u.handleConnections() - return nil -} - -func (u *unixServer) Close() (err error) { - _ = os.Remove(testUnixServerAddress) - return u.conn.Close() -} - -func (u *unixServer) handleConnections() { - var conn net.Conn - var err error - conn, err = u.conn.AcceptUnix() - if err != nil { - panic(fmt.Errorf("could not accept connection: %v", err)) - } - u.handleConnection(conn) -} - -func (u *unixServer) handleConnection(conn net.Conn) { - _ = conn.SetDeadline(time.Now().Add(time.Second)) - - rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - _, err := rw.ReadString('\n') - if err != nil { - _, _ = rw.WriteString("failed to read input") - _ = rw.Flush() - } else { - resp := strings.Repeat("pong\n", u.rowsNumResp) - _, _ = rw.WriteString(resp) - _ = rw.Flush() - } -} diff --git a/src/go/plugin/go.d/pkg/socket/utils.go b/src/go/plugin/go.d/pkg/socket/utils.go index dcc48b383f5713..09b9e64eca0012 100644 --- a/src/go/plugin/go.d/pkg/socket/utils.go +++ b/src/go/plugin/go.d/pkg/socket/utils.go @@ -12,7 +12,7 @@ func IsUdpSocket(address string) bool { return strings.HasPrefix(address, "udp://") } -func networkType(address string) (string, string) { +func parseAddress(address string) (string, string) { switch { case IsUnixSocket(address): address = strings.TrimPrefix(address, "unix://")