Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement LDAP service discovery (RFC 2782) #362

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions add.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type AddRequest struct {
Controls []Control
}

func (req *AddRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(AddRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
93 changes: 78 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -101,6 +104,7 @@ type Conn struct {
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
rootDN string
}

var _ Client = &Conn{}
Expand Down Expand Up @@ -144,35 +148,89 @@ type DialContext struct {
tc *tls.Config
}

func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May you update the related tests and add some comments for documentation. The newly added functionality isn't really self-explanatory, especially when it comes down to the rootDN and baseDN

if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
}

host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
if u.Scheme != "ldap" && u.Scheme != "ldaps" {
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
}

switch u.Scheme {
case "ldap":
if port == "" {
port = DefaultLdapPort
hostports := make([]string, 0, 1)

if u.Host == "" {
// Attempt to use DNS SRV discovery for uri like ldap:///dc=example,dc=com
// For ldap:///dc=example,dc=com, it would query for _ldap._tcp.example.com with SRV type record
fragments := strings.Split(u.Path[1:], ",")
pieces := make([]string, 0, len(fragments))
for _, fragment := range fragments {
if strings.HasPrefix(fragment, "dc=") {
pieces = append(pieces, fragment[3:])
}
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":

domain := strings.Join(pieces, ".")
_, records, err := net.LookupSRV("ldap", "tcp", domain)
if err != nil {
return nil, err
}

sort.Slice(records, func(i, j int) bool {
return records[i].Priority > records[j].Priority
})

if u.Scheme == "ldaps" {
dc.tc = &tls.Config{
ServerName: domain,
}
}

for _, record := range records {
port := strconv.Itoa(int(record.Port))
hostports = append(hostports, net.JoinHostPort(record.Target, port))
}
} else {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
}

if port == "" {
port = DefaultLdapsPort
if u.Scheme == "ldap" {
port = DefaultLdapPort
} else if u.Scheme == "ldaps" {
port = DefaultLdapsPort
}
}

hostports = []string{net.JoinHostPort(host, port)}
}

for _, pair := range hostports {
conn, err = dc.dialConn(u.Scheme, pair)
if conn != nil {
return conn, err
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
return
}

func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) {
switch scheme {
case "ldap":
return dc.d.Dial("tcp", target)
case "ldaps":
return tls.DialWithDialer(dc.d, "tcp", target, dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", scheme)
}

// Dial connects to the given address on the given network using net.Dial
Expand Down Expand Up @@ -223,7 +281,12 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
return nil, NewError(ErrorNetwork, err)
}

rootDN := ""
if u.Host == "" {
rootDN = u.Path[1:]
}
conn := NewConn(c, u.Scheme == "ldaps")
conn.rootDN = rootDN
conn.Start()
return conn, nil
}
Expand Down
7 changes: 7 additions & 0 deletions del.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ type DelRequest struct {
Controls []Control
}

func (req *DelRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(DelRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *DelRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request")
pkt.Data.Write([]byte(req.DN))
Expand Down
16 changes: 16 additions & 0 deletions dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,19 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
}

// appendDN is for concat the baseDN and rootDN
// dn stand for user input dn in request
// rootDN stand for dn used during discovery
func appendDN(dn, rootDN string) string {
if rootDN != "" {
var baseDnBuilder strings.Builder
if dn != "" {
baseDnBuilder.WriteString(dn)
baseDnBuilder.WriteByte(',')
}
baseDnBuilder.WriteString(rootDN)
return baseDnBuilder.String()
}
return dn
}
31 changes: 31 additions & 0 deletions dn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,34 @@ func TestDNAncestor(t *testing.T) {
}
}
}

func TestAppendDN(t *testing.T) {
testcases := []struct {
baseDN string
rootDN string
expected string
}{
{
baseDN: "ou=A",
rootDN: "dc=ldap,dc=internal",
expected: "ou=A,dc=ldap,dc=internal",
},
{
baseDN: "ou=A,dc=ldap,dc=internal",
rootDN: "",
expected: "ou=A,dc=ldap,dc=internal",
},
{
baseDN: "",
rootDN: "dc=ldap,dc=internal",
expected: "dc=ldap,dc=internal",
},
}

for i, tc := range testcases {
result := appendDN(tc.baseDN, tc.rootDN)
if result != tc.expected {
t.Errorf("#%d, expected %s, getting: %s", i, tc.expected, result)
}
}
}
8 changes: 8 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ var attributes = []string{
"description",
}

func TestDialURLViaDiscovery(t *testing.T) {
l, err := DialURL("ldap:///dc=umich,dc=edu")
if err != nil {
t.Fatal(err)
}
defer l.Close()
}

func TestUnsecureDialURL(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion moddn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi
//
// Refer NewModifyDNRequest for other parameters
func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
newSup string, controls []Control) *ModifyDNRequest {
newSup string, controls []Control) *ModifyDNRequest {
return &ModifyDNRequest{
DN: dn,
NewRDN: rdn,
Expand All @@ -50,6 +50,13 @@ func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
}
}

func (req *ModifyDNRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(ModifyDNRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
7 changes: 7 additions & 0 deletions modify.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals
req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}})
}

func (req *ModifyRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(ModifyRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *ModifyRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
10 changes: 10 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ type request interface {
appendTo(*ber.Packet) error
}

type appendDnRequest interface {
request
appendBaseDN(dn string) appendDnRequest
}

type requestFunc func(*ber.Packet) error

func (f requestFunc) appendTo(p *ber.Packet) error {
Expand All @@ -30,6 +35,11 @@ func (l *Conn) doRequest(req request) (*messageContext, error) {

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))

if areq, ok := req.(appendDnRequest); ok {
req = areq.appendBaseDN(l.rootDN)
}

if err := req.appendTo(packet); err != nil {
return nil, err
}
Expand Down
7 changes: 7 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ type SearchRequest struct {
Controls []Control
}

func (req *SearchRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(SearchRequest)
*r2 = *req
r2.BaseDN = appendDN(req.BaseDN, dn)
return r2
}

func (req *SearchRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN"))
Expand Down
7 changes: 7 additions & 0 deletions v3/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type AddRequest struct {
Controls []Control
}

func (req *AddRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(AddRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
Loading