273 lines
6.4 KiB
Go
273 lines
6.4 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-ldap/ldap/v3"
|
|
)
|
|
|
|
// LdapHandler structure
|
|
// Contains a pool of connection, and all information needed for authentication
|
|
type LdapHandler struct {
|
|
servers []string
|
|
baseDN string
|
|
bindCn string
|
|
bindPw string
|
|
searchFilter string
|
|
attribute string
|
|
pgpAttribute string
|
|
validValues []string
|
|
currentServer int
|
|
ssl bool
|
|
authCache *AuthCache
|
|
clients chan LdapClient
|
|
m sync.RWMutex
|
|
}
|
|
|
|
// LdapClient interface definition
|
|
type LdapClient interface {
|
|
ldap.Client
|
|
}
|
|
|
|
// NewLdap initializes a new LDAP structure
|
|
// We start with an empty pool, no need to prepare connections
|
|
func NewLdap(servers []string, bindCn, bindPw, baseDN, filter, attr, pgpAttr string, valid []string, nbConn int, ssl bool) *LdapHandler {
|
|
ldap.DefaultTimeout = 3 * time.Second
|
|
l := LdapHandler{
|
|
servers: servers,
|
|
baseDN: baseDN,
|
|
bindCn: bindCn,
|
|
bindPw: bindPw,
|
|
searchFilter: filter,
|
|
attribute: attr,
|
|
pgpAttribute: pgpAttr,
|
|
validValues: valid,
|
|
clients: make(chan LdapClient, nbConn),
|
|
authCache: NewAuthCache(),
|
|
ssl: ssl,
|
|
}
|
|
return &l
|
|
}
|
|
|
|
// PgpKeys get all knonw GPG keys in the directory that belong to entry
|
|
// matching the the Auth Profile
|
|
func (l *LdapHandler) PgpKeys() ([]string, error) {
|
|
// try the cache first
|
|
ret := l.authCache.PgpGet()
|
|
if len(ret) > 0 {
|
|
return ret, nil
|
|
}
|
|
// get a conn from the pool
|
|
conn, err := l.GetConn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// and put it back
|
|
defer l.BackToPool(conn)
|
|
|
|
// the search parameters
|
|
searchRequest := ldap.NewSearchRequest(
|
|
l.baseDN,
|
|
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
|
"(objectClass=pgpUserKeyInfo)",
|
|
[]string{l.attribute, l.pgpAttribute},
|
|
nil,
|
|
)
|
|
sr, err := conn.Search(searchRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, entry := range sr.Entries {
|
|
key := ""
|
|
valid := false
|
|
for _, attribute := range entry.Attributes {
|
|
switch (*attribute).Name {
|
|
case l.pgpAttribute:
|
|
key = attribute.Values[0]
|
|
case l.attribute:
|
|
if inArray(l.validValues, attribute.Values) {
|
|
valid = true
|
|
}
|
|
}
|
|
}
|
|
if valid && key != "" {
|
|
ret = append(ret, l.PgpNormalizeKey(key))
|
|
}
|
|
}
|
|
if len(ret) > 0 {
|
|
l.authCache.PgpSet(ret)
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// Auth method fot LDAP
|
|
// We check that the subject exists in the ldap
|
|
// If it's the case, we search for the attribute defined in the LdapHandler
|
|
// This attribute's value must then be one of the registered value in the LdapHandler
|
|
func (l *LdapHandler) Auth(subject string) (bool, error) {
|
|
// use the cache if possible
|
|
if l.authCache.Get(subject) {
|
|
return true, nil
|
|
}
|
|
// get a conn from the pool
|
|
conn, err := l.GetConn()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
// and put it back
|
|
defer l.BackToPool(conn)
|
|
|
|
// the search parameters
|
|
searchRequest := ldap.NewSearchRequest(
|
|
l.baseDN,
|
|
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
|
fmt.Sprintf(l.searchFilter, subject),
|
|
[]string{l.attribute},
|
|
nil,
|
|
)
|
|
sr, err := conn.Search(searchRequest)
|
|
|
|
// if the search failed or returns more than one entry, it's not valid
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if len(sr.Entries) != 1 {
|
|
return false, errors.New("User does not exist or too many entries returned")
|
|
}
|
|
// validate the values returned
|
|
for _, attribute := range sr.Entries[0].Attributes {
|
|
if (*attribute).Name == l.attribute {
|
|
if inArray(l.validValues, attribute.Values) {
|
|
l.authCache.Set(subject)
|
|
return true, nil
|
|
}
|
|
}
|
|
}
|
|
// we found nothing, auth failed
|
|
return false, nil
|
|
}
|
|
|
|
// NewConn creates a ldap client object
|
|
func (l *LdapHandler) NewConn() (LdapClient, error) {
|
|
var err error
|
|
var conn *ldap.Conn
|
|
// the server list cannot be empty
|
|
if len(l.servers) == 0 {
|
|
return nil, errors.New("Empty server list")
|
|
}
|
|
// circle the server list
|
|
for i := 0; i < len(l.servers); i++ {
|
|
l.currentServer++
|
|
if l.currentServer >= len(l.servers) {
|
|
l.currentServer = 0
|
|
}
|
|
if l.ssl {
|
|
conn, err = ldap.DialTLS("tcp",
|
|
fmt.Sprintf("%s:%d", l.servers[l.currentServer], 636),
|
|
&tls.Config{ServerName: l.servers[l.currentServer]})
|
|
} else {
|
|
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d",
|
|
l.servers[l.currentServer], 389))
|
|
}
|
|
// if the connection fails, maybe there is another server available
|
|
if err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
// First bind with a read only user
|
|
// if it work, we are done
|
|
if err = conn.Bind(l.bindCn, l.bindPw); err == nil {
|
|
return conn, nil
|
|
}
|
|
}
|
|
// no working server were found
|
|
return nil, errors.New("No valid ldap server found")
|
|
}
|
|
|
|
// GetConn retrieves a new connection from the pool
|
|
func (l *LdapHandler) GetConn() (LdapClient, error) {
|
|
// check if the pool is valid
|
|
if l.clients == nil {
|
|
return nil, errors.New("Pool is closed")
|
|
}
|
|
select {
|
|
case conn := <-l.clients:
|
|
if conn == nil {
|
|
return nil, errors.New("Pool is closed")
|
|
}
|
|
if ldapIsAlive(conn) {
|
|
return conn, nil
|
|
}
|
|
// dead connection, restart it
|
|
conn.Close()
|
|
return l.NewConn()
|
|
default:
|
|
// No more conn in Pool, create a new one and return it
|
|
return l.NewConn()
|
|
}
|
|
}
|
|
|
|
// BackToPool returns a connection to the pool
|
|
func (l *LdapHandler) BackToPool(p LdapClient) error {
|
|
// if it's nil, stop here
|
|
if p == nil {
|
|
return errors.New("Connexion is closed")
|
|
}
|
|
// check if it's alive. If not, no need to put it back
|
|
if !ldapIsAlive(p) {
|
|
p.Close()
|
|
return errors.New("returned connection was closed")
|
|
}
|
|
// same if the pool is not active
|
|
if l.clients == nil {
|
|
p.Close()
|
|
return errors.New("Pool is closed")
|
|
}
|
|
select {
|
|
case l.clients <- p:
|
|
return nil
|
|
default:
|
|
// Pool is full
|
|
p.Close()
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// make a dummy request to validate that the server is alive
|
|
func ldapIsAlive(client LdapClient) bool {
|
|
_, err := client.Search(&ldap.SearchRequest{BaseDN: "", Scope: ldap.ScopeBaseObject, Filter: "(&)", Attributes: []string{"1.1"}})
|
|
return err == nil
|
|
}
|
|
|
|
// Len returns information about the pool
|
|
func (l *LdapHandler) Len() int {
|
|
return len(l.clients)
|
|
}
|
|
|
|
// Cap returns cap information about the pool
|
|
func (l *LdapHandler) Cap() int {
|
|
return cap(l.clients)
|
|
}
|
|
|
|
// PgpNormalizeKey replace space with new line in ldap stored GPG keys
|
|
func (l *LdapHandler) PgpNormalizeKey(key string) string {
|
|
var rkey, sep string
|
|
for _, sub := range strings.Split(key, " ") {
|
|
rkey = rkey + sep + sub
|
|
|
|
if strings.HasPrefix(sub, "---") {
|
|
sep = " "
|
|
}
|
|
if strings.HasSuffix(sub, "---") {
|
|
sep = "\n"
|
|
}
|
|
}
|
|
return rkey
|
|
}
|