Skip to content

Commit

Permalink
Guard against panic caused by misusing Compiler
Browse files Browse the repository at this point in the history
Close: #149
  • Loading branch information
hillu committed May 6, 2024
1 parent 2e52045 commit 4f5cd55
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
26 changes: 20 additions & 6 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,28 @@ func (c *Compiler) setCallbackData(cb CompilerIncludeFunc) {
}
}

var (
errParse = errors.New("Compiler cannot be used after parse error")
errRules = errors.New("Compiler cannot be used after producing rule set")
)

func (c *Compiler) checkUsage() (err error) {
if c.cptr.errors != 0 {
err = errParse
} else if c.cptr.rules != nil {
err = errRules
}
return
}

// AddFile compiles rules from a file. Rules are added to the
// specified namespace.
//
// If this function returns an error, the Compiler object will become
// unusable.
func (c *Compiler) AddFile(file *os.File, namespace string) (err error) {
if c.cptr.errors != 0 {
return errors.New("Compiler cannot be used after parse error")
if err := c.checkUsage(); err != nil {
return err
}
var ns *C.char
if namespace != "" {
Expand Down Expand Up @@ -164,8 +178,8 @@ func (c *Compiler) AddFile(file *os.File, namespace string) (err error) {
// If this function returns an error, the Compiler object will become
// unusable.
func (c *Compiler) AddString(rules string, namespace string) (err error) {
if c.cptr.errors != 0 {
return errors.New("Compiler cannot be used after parse error")
if err := c.checkUsage(); err != nil {
return err
}
var ns *C.char
if namespace != "" {
Expand Down Expand Up @@ -224,8 +238,8 @@ func (c *Compiler) DefineVariable(identifier string, value interface{}) (err err

// GetRules returns the compiled ruleset.
func (c *Compiler) GetRules() (*Rules, error) {
if c.cptr.errors != 0 {
return nil, errors.New("Compiler cannot be used after parse error")
if err := c.checkUsage(); err != nil {
return nil, err
}
var yrRules *C.YR_RULES
if err := newError(C.yr_compiler_get_rules(c.cptr, &yrRules)); err != nil {
Expand Down
13 changes: 13 additions & 0 deletions compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ func TestErrors(t *testing.T) {
}
}

func TestErrorNoPanic(t *testing.T) {
c, _ := NewCompiler()
c.AddString("rule test { condition: true }", "")
if _, err := c.GetRules(); err != nil {
t.Errorf("did not expect error: %v", err)
}
if err := c.AddString("rule test { }", ""); err == nil {
t.Error("expected AddString after GetRules to fail")
} else {
t.Logf("got error as expected: %v", err)
}
}

func setupCompiler(t *testing.T) *Compiler {
c, err := NewCompiler()
if err != nil {
Expand Down

0 comments on commit 4f5cd55

Please sign in to comment.