Skip to content

Commit

Permalink
refactor(go): Change the serialization API in Golang.
Browse files Browse the repository at this point in the history
The existing `Serialize` and `Deserialize` functions are replaced by `WriteTo` and `ReadFrom`, which write to an `io.Writer` and read from an `io.Reader` respectively. This new API is more efficient because it doesn't need to make a copy of the compiled rules in memory.

This also removes an issue that existed in `Serialize` when serialized rules are larger than 4GB. It turns out that `C.GoBytes` receives a length of type `C.int`, which is a 32-bit integer, effectively limiting the serialized rules to less than 4GB.
  • Loading branch information
plusvic committed Sep 13, 2024
1 parent d7db62b commit e8bf6ed
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 21 deletions.
31 changes: 20 additions & 11 deletions go/compiler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package yara_x

import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
)
Expand Down Expand Up @@ -50,12 +51,20 @@ func TestSerialization(t *testing.T) {
r, err := Compile("rule test { condition: true }")
assert.NoError(t, err)

b, _ := r.Serialize()
r, _ = Deserialize(b)
var buf bytes.Buffer

// Write rules into buffer
n, err := r.WriteTo(&buf)

assert.NoError(t, err)
assert.Len(t, buf.Bytes(), int(n))

// Read rules from buffer
r, _ = ReadFrom(&buf)

// Make sure the rules work properly.
s := NewScanner(r)
scanResults, _ := s.Scan([]byte{})

assert.Len(t, scanResults.MatchingRules(), 1)
}

Expand Down Expand Up @@ -163,8 +172,8 @@ func TestRulesIter(t *testing.T) {
}`)
assert.NoError(t, err)

rules := c.Build()
assert.Equal(t, 2, rules.Count())
rules := c.Build()
assert.Equal(t, 2, rules.Count())

slice := rules.Slice()
assert.Len(t, slice, 2)
Expand All @@ -177,7 +186,7 @@ func TestRulesIter(t *testing.T) {
assert.Len(t, slice[0].Metadata(), 0)
assert.Len(t, slice[1].Metadata(), 1)

assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier())
assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier())
}

func TestImportsIter(t *testing.T) {
Expand All @@ -193,12 +202,12 @@ func TestImportsIter(t *testing.T) {
}`)
assert.NoError(t, err)

rules := c.Build()
imports := rules.Imports()
rules := c.Build()
imports := rules.Imports()

assert.Len(t, imports, 2)
assert.Equal(t, "pe", imports[0])
assert.Equal(t, "elf", imports[1])
assert.Len(t, imports, 2)
assert.Equal(t, "pe", imports[0])
assert.Equal(t, "elf", imports[1])
}

func TestWarnings(t *testing.T) {
Expand Down
70 changes: 60 additions & 10 deletions go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import "C"

import (
"errors"
"io"
"reflect"
"runtime"
"runtime/cgo"
"unsafe"
Expand All @@ -49,25 +51,30 @@ func Compile(src string, opts ...CompileOption) (*Rules, error) {
return c.Build(), nil
}

// Deserialize deserializes rules from a byte slice.
// ReadFrom reads compiled rules from a reader.
//
// The counterpart is [Rules.Serialize]
func Deserialize(data []byte) (*Rules, error) {
// The counterpart is [Rules.WriteTo].
//
// The reader is drained with io.ReadAll, so the whole serialized form is held
// in memory before it is handed to yrx_rules_deserialize. An error from the
// reader is returned unchanged; if deserialization fails, the returned error
// carries the message reported by yrx_last_error.
func ReadFrom(r io.Reader) (*Rules, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}

// ptr stays nil when no data was read, which avoids indexing an empty slice.
var ptr *C.uint8_t
if len(data) > 0 {
ptr = (*C.uint8_t)(unsafe.Pointer(&(data[0])))
}

r := &Rules{cRules: nil}
rules := &Rules{cRules: nil}

// Keep this goroutine on one OS thread for both C calls below; presumably
// yrx_last_error reports a per-thread error — TODO confirm.
runtime.LockOSThread()
defer runtime.UnlockOSThread()

if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &r.cRules) != C.SUCCESS {
if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &rules.cRules) != C.SUCCESS {
return nil, errors.New(C.GoString(C.yrx_last_error()))
}

return r, nil
return rules, nil
}

// Rules represents a set of compiled YARA rules.
Expand All @@ -79,17 +86,60 @@ func (r *Rules) Scan(data []byte) (*ScanResults, error) {
return scanner.Scan(data)
}

// Serialize converts the compiled rules into a byte slice.
func (r *Rules) Serialize() ([]byte, error) {
// WriteTo writes the compiled rules into a writer.
//
// The counterpart is [ReadFrom].
//
// The rules are serialized with yrx_rules_serialize and the resulting buffer
// is written to w in chunks of at most 64MB. On success it returns the sum of
// the byte counts reported by w.Write. If serialization fails, it returns 0
// and the message from yrx_last_error; if a write fails, it returns 0 and the
// writer's error.
func (r *Rules) WriteTo(w io.Writer) (int64, error) {
var buf *C.YRX_BUFFER
// Keep this goroutine on one OS thread for the C calls below; presumably
// yrx_last_error reports a per-thread error — TODO confirm.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
if C.yrx_rules_serialize(r.cRules, &buf) != C.SUCCESS {
return nil, errors.New(C.GoString(C.yrx_last_error()))
return 0, errors.New(C.GoString(C.yrx_last_error()))
}
// Destroy the buffer when this function returns, after every chunk has been
// handed to the writer.
defer C.yrx_buffer_destroy(buf)
// Keep r reachable until the serialize call above has completed.
runtime.KeepAlive(r)
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)), nil

// We are going to write into `w` in chunks of 64MB.
const chunkSize = 1 << 26

// This is the slice that contains the next chunk that will be written.
var chunk []byte

// Modify the `chunk` slice, making it point to the buffer returned
// by yrx_rules_serialize. This allows us to access the buffer from
// Go without copying the data. This is safe because the slice won't
// be used after the buffer is destroyed.
//
// NOTE(review): reflect.SliceHeader is deprecated in recent Go releases in
// favor of unsafe.Slice — confirm against the Go version this module targets.
// The length set here may exceed the real buffer size; it is clamped at the
// top of the loop before the first write.
chunkHdr := (*reflect.SliceHeader)(unsafe.Pointer(&chunk))
chunkHdr.Data = uintptr(unsafe.Pointer(buf.data))
chunkHdr.Len = chunkSize
chunkHdr.Cap = chunkSize

// NOTE(review): the width of C.ulong is platform-dependent (it can be 32 bits
// on some 64-bit targets) — confirm this conversion cannot truncate buf.length.
bufLen := C.ulong(buf.length)
bytesWritten := int64(0)

for {
// If the data to be written is shorter than `chunkSize`, set the length
// of the `chunk` slice to this length.
if bufLen < chunkSize {
chunkHdr.Len = int(bufLen)
chunkHdr.Cap = int(bufLen)
}
if n, err := w.Write(chunk); err == nil {
bytesWritten += int64(n)
} else {
// NOTE(review): bytes written by earlier iterations are not reported on
// this path; callers relying on a partial count will see 0 — confirm
// this is intended.
return 0, err
}
// If `bufLen` is still greater than `chunkSize`, there's more data to
// write; if not, we are done.
if bufLen > chunkSize {
chunkHdr.Data += chunkSize
bufLen -= chunkSize
} else {
break
}
}

return bytesWritten, nil
}

// Destroy destroys the compiled YARA rules represented by [Rules].
Expand Down
8 changes: 8 additions & 0 deletions go/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ func TestScanner1(t *testing.T) {
assert.Equal(t, "t", matchingRules[0].Identifier())
assert.Equal(t, "default", matchingRules[0].Namespace())
assert.Len(t, matchingRules[0].Patterns(), 0)

scanResults, _ = s.Scan(nil)
matchingRules = scanResults.MatchingRules()

assert.Len(t, matchingRules, 1)
assert.Equal(t, "t", matchingRules[0].Identifier())
assert.Equal(t, "default", matchingRules[0].Namespace())
assert.Len(t, matchingRules[0].Patterns(), 0)
}

func TestScanner2(t *testing.T) {
Expand Down

0 comments on commit e8bf6ed

Please sign in to comment.