From e8bf6ed921cac8e5ac574b9e8e9f48d34a9b05cb Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Fri, 13 Sep 2024 18:15:41 +0200 Subject: [PATCH] refactor(go): Change the serialization API in Golang. The existing `Serialize` and `Deserialize` functions are replaced by `WriteTo` and `ReadFrom`, which write to an `io.Writer` and read from an `io.Reader` respectively. This new API is more efficient because it doesn't need to make a copy of the compiled rules in memory. This also removes an issue that existed in `Serialize` when serialized rules are larger than 2GB. It turns out that `C.GoBytes` receives a length of type `C.int`, which is a signed 32-bit integer, effectively limiting the serialized rules to less than 2GB. --- go/compiler_test.go | 31 +++++++++++++------- go/main.go | 70 ++++++++++++++++++++++++++++++++++++++------- go/scanner_test.go | 8 ++++++ 3 files changed, 88 insertions(+), 21 deletions(-) diff --git a/go/compiler_test.go b/go/compiler_test.go index 115f032f..ecb93208 100644 --- a/go/compiler_test.go +++ b/go/compiler_test.go @@ -1,6 +1,7 @@ package yara_x import ( + "bytes" "github.com/stretchr/testify/assert" "testing" ) @@ -50,12 +51,20 @@ func TestSerialization(t *testing.T) { r, err := Compile("rule test { condition: true }") assert.NoError(t, err) - b, _ := r.Serialize() - r, _ = Deserialize(b) + var buf bytes.Buffer + // Write rules into buffer + n, err := r.WriteTo(&buf) + + assert.NoError(t, err) + assert.Len(t, buf.Bytes(), int(n)) + + // Read rules from buffer + r, _ = ReadFrom(&buf) + + // Make sure the rules work properly.
s := NewScanner(r) scanResults, _ := s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 1) } @@ -163,8 +172,8 @@ func TestRulesIter(t *testing.T) { }`) assert.NoError(t, err) - rules := c.Build() - assert.Equal(t, 2, rules.Count()) + rules := c.Build() + assert.Equal(t, 2, rules.Count()) slice := rules.Slice() assert.Len(t, slice, 2) @@ -177,7 +186,7 @@ func TestRulesIter(t *testing.T) { assert.Len(t, slice[0].Metadata(), 0) assert.Len(t, slice[1].Metadata(), 1) - assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) + assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) } func TestImportsIter(t *testing.T) { @@ -193,12 +202,12 @@ func TestImportsIter(t *testing.T) { }`) assert.NoError(t, err) - rules := c.Build() - imports := rules.Imports() + rules := c.Build() + imports := rules.Imports() - assert.Len(t, imports, 2) - assert.Equal(t, "pe", imports[0]) - assert.Equal(t, "elf", imports[1]) + assert.Len(t, imports, 2) + assert.Equal(t, "pe", imports[0]) + assert.Equal(t, "elf", imports[1]) } func TestWarnings(t *testing.T) { diff --git a/go/main.go b/go/main.go index 1368c451..f9a11438 100644 --- a/go/main.go +++ b/go/main.go @@ -31,6 +31,8 @@ import "C" import ( "errors" + "io" + "reflect" "runtime" "runtime/cgo" "unsafe" @@ -49,25 +51,30 @@ func Compile(src string, opts ...CompileOption) (*Rules, error) { return c.Build(), nil } -// Deserialize deserializes rules from a byte slice. +// ReadFrom reads compiled rules from a reader. // -// The counterpart is [Rules.Serialize] -func Deserialize(data []byte) (*Rules, error) { +// The counterpart is [Rules.WriteTo]. 
+func ReadFrom(r io.Reader) (*Rules, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + var ptr *C.uint8_t if len(data) > 0 { ptr = (*C.uint8_t)(unsafe.Pointer(&(data[0]))) } - r := &Rules{cRules: nil} + rules := &Rules{cRules: nil} runtime.LockOSThread() defer runtime.UnlockOSThread() - if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &r.cRules) != C.SUCCESS { + if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &rules.cRules) != C.SUCCESS { return nil, errors.New(C.GoString(C.yrx_last_error())) } - return r, nil + return rules, nil } // Rules represents a set of compiled YARA rules. @@ -79,17 +86,60 @@ func (r *Rules) Scan(data []byte) (*ScanResults, error) { return scanner.Scan(data) } -// Serialize converts the compiled rules into a byte slice. -func (r *Rules) Serialize() ([]byte, error) { +// WriteTo writes the compiled rules into a writer. +// +// The counterpart is [ReadFrom]. +func (r *Rules) WriteTo(w io.Writer) (int64, error) { var buf *C.YRX_BUFFER runtime.LockOSThread() defer runtime.UnlockOSThread() if C.yrx_rules_serialize(r.cRules, &buf) != C.SUCCESS { - return nil, errors.New(C.GoString(C.yrx_last_error())) + return 0, errors.New(C.GoString(C.yrx_last_error())) } defer C.yrx_buffer_destroy(buf) runtime.KeepAlive(r) - return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)), nil + + // We are going to write into `w` in chunks of 64MB. + const chunkSize = 1 << 26 + + // This is the slice that contains the next chunk that will be written. + var chunk []byte + + // Modify the `chunk` slice, making it point to the buffer returned + // by yrx_rules_serialize. This allows us to access the buffer from + // Go without copying the data. This is safe because the slice won't + // be used after the buffer is destroyed. 
+ chunkHdr := (*reflect.SliceHeader)(unsafe.Pointer(&chunk)) + chunkHdr.Data = uintptr(unsafe.Pointer(buf.data)) + chunkHdr.Len = chunkSize + chunkHdr.Cap = chunkSize + + bufLen := C.ulong(buf.length) + bytesWritten := int64(0) + + for { + // If the data to be written is shorter than `chunkSize`, set the length + // of the `chunk` slice to this length. + if bufLen < chunkSize { + chunkHdr.Len = int(bufLen) + chunkHdr.Cap = int(bufLen) + } + if n, err := w.Write(chunk); err == nil { + bytesWritten += int64(n) + } else { + return 0, err + } + // If `bufLen` is still greater than `chunkSize`, there's more data to + // write; if not, we are done. + if bufLen > chunkSize { + chunkHdr.Data += chunkSize + bufLen -= chunkSize + } else { + break + } + } + + return bytesWritten, nil } // Destroy destroys the compiled YARA rules represented by [Rules]. diff --git a/go/scanner_test.go b/go/scanner_test.go index 60a76a06..a888e64c 100644 --- a/go/scanner_test.go +++ b/go/scanner_test.go @@ -19,6 +19,14 @@ func TestScanner1(t *testing.T) { assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) assert.Len(t, matchingRules[0].Patterns(), 0) + + scanResults, _ = s.Scan(nil) + matchingRules = scanResults.MatchingRules() + + assert.Len(t, matchingRules, 1) + assert.Equal(t, "t", matchingRules[0].Identifier()) + assert.Equal(t, "default", matchingRules[0].Namespace()) + assert.Len(t, matchingRules[0].Patterns(), 0) } func TestScanner2(t *testing.T) {