Skip to content

Commit

Permalink
test using gorm serialiser instead of custom hooks
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Sep 30, 2024
1 parent b7bbdaa commit e691a12
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 213 deletions.
15 changes: 13 additions & 2 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)

func init() {
schema.RegisterSerializer("text", TextSerialiser{})
}

var errDatabaseNotSupported = errors.New("database type not supported")

// KV is a key-value store in a psql table. For future use...
Expand All @@ -31,7 +36,8 @@ type KV struct {
}

type HSDatabase struct {
DB *gorm.DB
DB *gorm.DB
cfg *types.DatabaseConfig

baseDomain string
}
Expand Down Expand Up @@ -421,7 +427,8 @@ func NewHeadscaleDatabase(
}

db := HSDatabase{
DB: dbConn,
DB: dbConn,
cfg: &cfg,

baseDomain: baseDomain,
}
Expand Down Expand Up @@ -621,6 +628,10 @@ func (hsdb *HSDatabase) Close() error {
return err
}

if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
db.Exec("VACUUM")
}

return db.Close()
}

Expand Down
23 changes: 23 additions & 0 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,29 @@ func TestMigrations(t *testing.T) {
}
},
},
{
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite3",
wantFunc: func(t *testing.T, h *HSDatabase) {
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
assert.NoError(t, err)

for _, node := range nodes {
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
assert.Contains(t, node.MachineKey.String(), "mkey:")
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
assert.Contains(t, node.NodeKey.String(), "nodekey:")
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
assert.Contains(t, node.DiscoKey.String(), "discokey:")
assert.NotNil(t, node.IPv4)
assert.NotNil(t, node.IPv4)
assert.Len(t, node.Endpoints, 1)
assert.NotNil(t, node.Hostinfo)
assert.NotNil(t, node.MachineKey)
}
},
},
}

for _, tt := range tests {
Expand Down
37 changes: 0 additions & 37 deletions hscontrol/db/ip_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package db

import (
"database/sql"
"fmt"
"net/netip"
"strings"
Expand Down Expand Up @@ -291,15 +290,7 @@ func TestBackfillIPAddresses(t *testing.T) {
v4 := fmt.Sprintf("100.64.0.%d", i)
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
return &types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: v4,
},
IPv4: nap(v4),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: v6,
},
IPv6: nap(v6),
}
}
Expand Down Expand Up @@ -331,15 +322,7 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand All @@ -364,15 +347,7 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand All @@ -397,10 +372,6 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv4: nap("100.64.0.1"),
},
},
Expand All @@ -425,10 +396,6 @@ func TestBackfillIPAddresses(t *testing.T) {

want: types.Nodes{
&types.Node{
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
IPv6: nap("fd7a:115c:a1e0::1"),
},
},
Expand Down Expand Up @@ -474,13 +441,9 @@ func TestBackfillIPAddresses(t *testing.T) {

comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
"ID",
"MachineKeyDatabaseField",
"NodeKeyDatabaseField",
"DiscoKeyDatabaseField",
"User",
"UserID",
"Endpoints",
"HostinfoDatabaseField",
"Hostinfo",
"Routes",
"CreatedAt",
Expand Down
10 changes: 7 additions & 3 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
Expand Down Expand Up @@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

adminNode, err := db.GetNodeByID(1)
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
c.Assert(adminNode.IPv4, check.NotNil)
c.Assert(adminNode.IPv6, check.IsNil)
c.Assert(err, check.IsNil)

testNode, err := db.GetNodeByID(2)
Expand All @@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

adminPeers, err := db.ListPeers(adminNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(adminPeers), check.Equals, 9)

testPeers, err := db.ListPeers(testNode.ID)
c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9)

adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
c.Assert(err, check.IsNil)
Expand All @@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {

peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)

c.Log(peersOfAdminNode)
c.Log(peersOfTestNode)

c.Assert(len(peersOfTestNode), check.Equals, 9)
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")

c.Log(peersOfAdminNode)
c.Assert(len(peersOfAdminNode), check.Equals, 9)
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
Expand Down
99 changes: 99 additions & 0 deletions hscontrol/db/text_serialiser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package db

import (
"context"
"encoding"
"fmt"
"reflect"

"gorm.io/gorm/schema"
)

// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

func isTextUnmarshaler(rv reflect.Value) bool {
return rv.Type().Implements(textUnmarshalerType)
}

func maybeInstantiatePtr(rv reflect.Value) {
if rv.Kind() == reflect.Ptr && rv.IsNil() {
np := reflect.New(rv.Type().Elem())
rv.Set(np)
}
}

func decodingError(name string, err error) error {
return fmt.Errorf("error decoding to %s: %w", name, err)
}

// TextSerialiser implements the Serialiser interface for fields that
// have a type that implements encoding.TextUnmarshaler.
type TextSerialiser struct{}

func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)

// If the field is a pointer, we need to dereference it to get the actual type
// so we do not end with a second pointer.
if fieldValue.Elem().Kind() == reflect.Ptr {
fieldValue = fieldValue.Elem()
}

if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
}

if isTextUnmarshaler(fieldValue) {
maybeInstantiatePtr(fieldValue)
f := fieldValue.MethodByName("UnmarshalText")
args := []reflect.Value{reflect.ValueOf(bytes)}
ret := f.Call(args)
if !ret[0].IsNil() {
return decodingError(field.Name, ret[0].Interface().(error))
}

// If the underlying field is to a pointer type, we need to
// assign the value as a pointer to it.
// If it is not a pointer, we need to assign the value to the
// field.
dstField := field.ReflectValueOf(ctx, dst)
if dstField.Kind() == reflect.Ptr {
dstField.Set(fieldValue)
} else {
dstField.Set(fieldValue.Elem())
}
return nil
} else {
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
}
}

return
}

func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
switch v := fieldValue.(type) {
case encoding.TextMarshaler:
// If the value is nil, we return nil, however, go nil values are not
// always comparable, particularly when reflection is involved:
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
return nil, nil
}
b, err := v.MarshalText()
if err != nil {
return nil, err
}
return string(b), nil
default:
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
}
}
38 changes: 29 additions & 9 deletions hscontrol/policy/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -593,6 +594,11 @@ func (pol *ACLPolicy) ExpandAlias(
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
// that are correctly tagged since they should not be listed as being in the user
// we assume in this function that we only have nodes from 1 user.
//
// TODO(kradalby): It is quite hard to understand what this function is doing,
// it seems like it trying to ensure that we dont include nodes that are tagged
// when we look up the nodes owned by a user.
// This should be refactored to be more clear as part of the Tags work in #1369
func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy,
nodes types.Nodes,
Expand All @@ -611,17 +617,16 @@ func excludeCorrectlyTaggedNodes(
for _, node := range nodes {
found := false

if node.Hostinfo == nil {
continue
}

for _, t := range node.Hostinfo.RequestTags {
if util.StringOrPrefixListContains(tags, t) {
found = true
if node.Hostinfo != nil {
for _, t := range node.Hostinfo.RequestTags {
if util.StringOrPrefixListContains(tags, t) {
found = true

break
break
}
}
}

if len(node.ForcedTags) > 0 {
found = true
}
Expand Down Expand Up @@ -966,6 +971,14 @@ func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
return out
}

func FilterRulesToMatchers(filter []tailcfg.FilterRule) []matcher.Match {
matchers := make([]matcher.Match, len(filter))
for i, rule := range filter {
matchers[i] = matcher.MatchFromFilterRule(rule)
}
return matchers
}

// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
func FilterNodesByACL(
node *types.Node,
Expand All @@ -974,12 +987,19 @@ func FilterNodesByACL(
) types.Nodes {
var result types.Nodes

// TODO(kradalby): Regenerate this everytime the filter change, instead of
// every time we use it.
matchers := FilterRulesToMatchers(filter)

for index, peer := range nodes {
if peer.ID == node.ID {
continue
}

if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname)

if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname)
result = append(result, peer)
}
}
Expand Down
Loading

0 comments on commit e691a12

Please sign in to comment.