pdns-auth-proxy/ldap.go

273 lines
6.4 KiB
Go

package main
import (
	"crypto/tls"
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/go-ldap/ldap/v3"
)
// LdapHandler authorizes subjects against an LDAP directory and serves the PGP
// keys stored there. It keeps a bounded pool of bound connections in clients
// and remembers successful lookups in authCache.
type LdapHandler struct {
	// Directory endpoints, search base and the service account used to bind.
	servers                []string
	baseDN, bindCn, bindPw string

	// Lookup settings: search filter template (formatted with the subject),
	// authorization attribute, PGP key attribute, and the attribute values
	// that grant access.
	searchFilter, attribute, pgpAttribute string
	validValues                           []string

	// Round-robin cursor into servers.
	currentServer int
	// Dial LDAPS (TLS) instead of plain LDAP.
	ssl bool

	authCache *AuthCache
	// Idle connection pool; its capacity bounds how many idle connections are kept.
	clients chan LdapClient
	m       sync.RWMutex
}
// LdapClient is the connection type handled by the pool. It embeds
// ldap.Client, so any ldap.Client implementation satisfies it, including the
// *ldap.Conn values produced by NewConn.
type LdapClient interface{ ldap.Client }
// NewLdap builds an LdapHandler for the given servers and search settings.
//
// The connection pool starts empty. Connections are dialed on demand by
// GetConn, and BackToPool keeps at most nbConn idle ones. A negative nbConn is
// treated as 0 (no idle connections kept) instead of panicking inside make.
//
// Side effect: it sets the package-level ldap.DefaultTimeout to 3 seconds. This
// setting is shared by every user of the go-ldap package in the process.
func NewLdap(servers []string, bindCn, bindPw, baseDN, filter, attr, pgpAttr string, valid []string, nbConn int, ssl bool) *LdapHandler {
	ldap.DefaultTimeout = 3 * time.Second
	// make(chan T, n) panics when n < 0; fall back to an unpooled handler.
	if nbConn < 0 {
		nbConn = 0
	}
	return &LdapHandler{
		servers:      servers,
		baseDN:       baseDN,
		bindCn:       bindCn,
		bindPw:       bindPw,
		searchFilter: filter,
		attribute:    attr,
		pgpAttribute: pgpAttr,
		validValues:  valid,
		clients:      make(chan LdapClient, nbConn),
		authCache:    NewAuthCache(),
		ssl:          ssl,
	}
}
// PgpKeys returns the PGP keys stored in the directory for every
// pgpUserKeyInfo entry whose l.attribute holds one of l.validValues.
//
// Keys come from the auth cache whenever it holds any. Otherwise the subtree
// under l.baseDN is searched, each key is passed through PgpNormalizeKey, and
// the result is cached if at least one key was found.
//
// Attribute names are compared case-insensitively, because LDAP attribute
// descriptions are case-insensitive and servers may return them with a
// different casing than configured. A key attribute that carries no value is
// ignored instead of causing an index-out-of-range panic.
func (l *LdapHandler) PgpKeys() ([]string, error) {
	// try the cache first
	if cached := l.authCache.PgpGet(); len(cached) > 0 {
		return cached, nil
	}
	conn, err := l.GetConn()
	if err != nil {
		return nil, err
	}
	// hand the connection back to the pool once the search is done
	defer l.BackToPool(conn)

	searchRequest := ldap.NewSearchRequest(
		l.baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		"(objectClass=pgpUserKeyInfo)",
		[]string{l.attribute, l.pgpAttribute},
		nil,
	)
	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}

	// Build a fresh slice rather than appending to whatever the cache returned,
	// so the cache's backing array is never written through.
	var keys []string
	for _, entry := range sr.Entries {
		key := ""
		valid := false
		for _, attribute := range entry.Attributes {
			switch {
			case strings.EqualFold(attribute.Name, l.pgpAttribute):
				if len(attribute.Values) > 0 {
					key = attribute.Values[0]
				}
			case strings.EqualFold(attribute.Name, l.attribute):
				if inArray(l.validValues, attribute.Values) {
					valid = true
				}
			}
		}
		if valid && key != "" {
			keys = append(keys, l.PgpNormalizeKey(key))
		}
	}
	if len(keys) > 0 {
		l.authCache.PgpSet(keys)
	}
	return keys, nil
}
// Auth reports whether subject is authorized by the directory.
//
// A cache hit returns true immediately. Otherwise l.searchFilter is formatted
// with the subject and searched under l.baseDN. Before formatting, the subject
// is escaped with ldap.EscapeFilter, so filter metacharacters such as '*', '('
// and ')' are matched literally and cannot change the filter.
//
// Exactly one entry must match; zero or several matches return an error. The
// subject is authorized, and then cached, when that entry's l.attribute
// (matched case-insensitively) holds one of l.validValues. Otherwise Auth
// returns (false, nil).
func (l *LdapHandler) Auth(subject string) (bool, error) {
	// use the cache if possible
	if l.authCache.Get(subject) {
		return true, nil
	}
	conn, err := l.GetConn()
	if err != nil {
		return false, err
	}
	// hand the connection back to the pool once the search is done
	defer l.BackToPool(conn)

	searchRequest := ldap.NewSearchRequest(
		l.baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		// escape the subject so it is matched as a value, never parsed as filter syntax
		fmt.Sprintf(l.searchFilter, ldap.EscapeFilter(subject)),
		[]string{l.attribute},
		nil,
	)
	sr, err := conn.Search(searchRequest)
	if err != nil {
		return false, err
	}
	// zero or several matches: the subject is not uniquely identified
	if len(sr.Entries) != 1 {
		return false, errors.New("User does not exist or too many entries returned")
	}
	for _, attribute := range sr.Entries[0].Attributes {
		if strings.EqualFold(attribute.Name, l.attribute) && inArray(l.validValues, attribute.Values) {
			l.authCache.Set(subject)
			return true, nil
		}
	}
	// the entry exists but carries no granting value
	return false, nil
}
// NewConn dials a new LDAP connection and binds it with the service account.
//
// Servers are tried round-robin, starting with the one after the server picked
// by the previous call; each server is tried at most once per call. A server
// entry is either a bare host, to which the default port is appended (636 with
// SSL, 389 without), or an explicit "host:port" / "[v6addr]:port". The first
// server that accepts both the connection and the bind is returned.
//
// Connections that were opened but failed to bind are closed rather than
// leaked. Dial and bind failures are logged.
//
// It returns an error when the server list is empty or no server could be
// dialed and bound.
func (l *LdapHandler) NewConn() (LdapClient, error) {
	// the server list cannot be empty
	if len(l.servers) == 0 {
		return nil, errors.New("Empty server list")
	}
	defaultPort := "389"
	if l.ssl {
		defaultPort = "636"
	}
	// circle the server list
	for i := 0; i < len(l.servers); i++ {
		addr, host := ldapServerAddress(l.servers[l.nextServer()], defaultPort)

		var conn *ldap.Conn
		var err error
		if l.ssl {
			conn, err = ldap.DialTLS("tcp", addr, &tls.Config{ServerName: host})
		} else {
			conn, err = ldap.Dial("tcp", addr)
		}
		// if the connection fails, maybe there is another server available
		if err != nil {
			log.Println(err)
			continue
		}
		// bind with the read-only service account; if it works, we are done
		if err = conn.Bind(l.bindCn, l.bindPw); err == nil {
			return conn, nil
		}
		// the bind failed: release this connection before trying the next server
		log.Println(err)
		conn.Close()
	}
	// no working server were found
	return nil, errors.New("No valid ldap server found")
}

// nextServer advances the round-robin cursor and returns the index of the
// server to try next. The cursor is shared by concurrent NewConn calls, so it
// is updated under l.m. Without the lock, one caller could observe an index
// equal to len(l.servers) and panic.
func (l *LdapHandler) nextServer() int {
	l.m.Lock()
	defer l.m.Unlock()
	l.currentServer++
	if l.currentServer >= len(l.servers) {
		l.currentServer = 0
	}
	return l.currentServer
}

// ldapServerAddress turns a configured server entry into a dial address and
// the host name used for TLS verification. Entries that already include a port
// are used as-is. Bare hosts, including unbracketed or bracketed IPv6
// literals, get defaultPort.
func ldapServerAddress(server, defaultPort string) (addr, host string) {
	if h, _, err := net.SplitHostPort(server); err == nil {
		return server, h
	}
	host = strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")
	return net.JoinHostPort(host, defaultPort), host
}
// GetConn hands out an LDAP connection, preferring an idle one from the pool.
//
// If the pool is empty, a fresh connection is dialed through NewConn. A pooled
// connection that fails the ldapIsAlive probe is closed and replaced by a new
// one. If the pool channel is nil, or was closed (a receive then yields nil),
// the error "Pool is closed" is returned.
func (l *LdapHandler) GetConn() (LdapClient, error) {
	if l.clients == nil {
		return nil, errors.New("Pool is closed")
	}

	var pooled LdapClient
	select {
	case pooled = <-l.clients:
	default:
		// nothing idle: dial a brand new connection
		return l.NewConn()
	}

	switch {
	case pooled == nil:
		return nil, errors.New("Pool is closed")
	case !ldapIsAlive(pooled):
		// stale connection: drop it and replace it
		pooled.Close()
		return l.NewConn()
	default:
		return pooled, nil
	}
}
// BackToPool gives p back to the idle pool so a later GetConn can reuse it.
//
// It rejects and reports, in this order:
//   - a nil connection;
//   - a connection that fails the ldapIsAlive probe (closed first);
//   - any connection while the pool channel is nil (closed first).
//
// When the pool is already full, p is closed and nil is returned.
func (l *LdapHandler) BackToPool(p LdapClient) error {
	switch {
	case p == nil:
		return errors.New("Connexion is closed")
	case !ldapIsAlive(p):
		p.Close()
		return errors.New("returned connection was closed")
	case l.clients == nil:
		p.Close()
		return errors.New("Pool is closed")
	}

	select {
	case l.clients <- p:
	default:
		// the pool is at capacity: discard the surplus connection
		p.Close()
	}
	return nil
}
// ldapIsAlive probes client with a cheap base-object search and reports
// whether the server answered without error. The search uses the empty base
// DN, the "(&)" filter and the attribute list {"1.1"}, which asks for no
// attributes.
func ldapIsAlive(client LdapClient) bool {
	probe := &ldap.SearchRequest{
		BaseDN:     "",
		Scope:      ldap.ScopeBaseObject,
		Filter:     "(&)",
		Attributes: []string{"1.1"},
	}
	if _, err := client.Search(probe); err != nil {
		return false
	}
	return true
}
// Len reports how many idle connections are currently waiting in the pool.
func (l *LdapHandler) Len() int {
	idle := len(l.clients)
	return idle
}
// Cap reports the maximum number of idle connections the pool can hold, that
// is, the capacity chosen when the handler was created.
func (l *LdapHandler) Cap() int {
	size := cap(l.clients)
	return size
}
// PgpNormalizeKey restores the line structure of an ASCII-armored PGP key read
// from the directory. It presumably was stored with its newlines flattened
// into single spaces; confirm against the data source.
//
// The key is split on single spaces, and each space is replaced by one
// separator. The separator depends on the most recent earlier token that
// starts or ends with "---":
//   - a token starting with "---" (e.g. "-----BEGIN") selects a space, so the
//     armor header line stays intact;
//   - a token ending with "---" (e.g. "BLOCK-----") selects a newline, so the
//     following base64 tokens each get their own line;
//   - if one token does both, the newline wins.
//
// Tokens before the first such marker are joined with no separator at all.
// Because each space becomes exactly one byte, the result has the same length
// as key.
func (l *LdapHandler) PgpNormalizeKey(key string) string {
	var b strings.Builder
	// each single-byte space is replaced by a single-byte separator
	b.Grow(len(key))
	sep := ""
	for _, sub := range strings.Split(key, " ") {
		b.WriteString(sep)
		b.WriteString(sub)
		if strings.HasPrefix(sub, "---") {
			sep = " "
		}
		if strings.HasSuffix(sub, "---") {
			sep = "\n"
		}
	}
	return b.String()
}