openvpn-mgt/ldap.go

174 lines
4.4 KiB
Go

package main
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"time"
"gopkg.in/ldap.v2"
)
// ldapConfig describes one LDAP authentication level: the servers to query,
// how to find and authenticate a user on them, and per-level settings.
type ldapConfig struct {
	// servers are the LDAP hosts; Auth dials them in order over LDAPS
	// (port 636) and uses the first one that accepts a connection.
	servers []string
	// baseDN is the root of the subtree searched for the user.
	baseDN string
	// bindCn and bindPw are the service-account credentials Auth binds with
	// before searching for the user.
	bindCn string
	bindPw string
	// searchFilter is a fmt format string; Auth substitutes the login into it
	// with fmt.Sprintf.
	searchFilter string
	// primaryAttribute and secondaryAttribute name the attributes Auth reads
	// from the user's entry.
	primaryAttribute   string
	secondaryAttribute string
	// validGroups, when non-empty, is checked against the primary attribute
	// values. With no servers configured, Auth uses it to filter the values
	// produced by the previous level.
	validGroups []string
	// NOTE(review): mfaType and certAuth are not read by the code visible
	// here; their meaning is defined by their users.
	mfaType  string
	certAuth string
	// ipMin and ipMax are the bounds of the address range set by addIPRange.
	ipMin net.IP
	ipMax net.IP
	// NOTE(review): upgradeFrom and routes are not read by the code visible
	// here; their meaning is defined by their users.
	upgradeFrom string
	routes      []string
}

// addIPRange parses an address range written "first-last"
// (e.g. "10.8.0.10-10.8.0.250") and stores its bounds in ipMin and ipMax.
// Spaces around each bound are ignored.
//
// It returns an error, and leaves the config unchanged, when s does not
// contain exactly one '-', when either bound is not a valid IP address, when
// the bounds mix IPv4 and IPv6, or when first is greater than last.
func (conf *ldapConfig) addIPRange(s string) error {
	bounds := strings.Split(s, "-")
	if len(bounds) != 2 {
		return errors.New("invalid IPs")
	}
	first := net.ParseIP(strings.TrimSpace(bounds[0]))
	last := net.ParseIP(strings.TrimSpace(bounds[1]))
	switch {
	case first == nil:
		return fmt.Errorf("invalid IPs: bad start address in %q", s)
	case last == nil:
		return fmt.Errorf("invalid IPs: bad end address in %q", s)
	case (first.To4() == nil) != (last.To4() == nil):
		return fmt.Errorf("invalid IPs: %q mixes IPv4 and IPv6", s)
	case string(first.To16()) > string(last.To16()):
		// Comparing the 16-byte forms as strings is a byte-wise
		// lexicographic comparison, i.e. the numeric order of the addresses.
		return fmt.Errorf("invalid IPs: start is after end in %q", s)
	}
	// Assign only once both bounds are known to be valid, so a bad range
	// never leaves the config half-updated.
	conf.ipMin, conf.ipMax = first, last
	return nil
}
// myDialTLS replaces ldap.DialTLS so that connection setup is bounded in time.
// The TCP connect and the TLS handshake must each complete within 3 seconds,
// so an unreachable or stalled server cannot block authentication. On success
// it returns a started LDAP connection running over TLS; on failure the
// underlying connection is closed and an ldap.ErrorNetwork error is returned.
func myDialTLS(network, addr string, config *tls.Config) (*ldap.Conn, error) {
	const timeout = 3 * time.Second

	raw, err := net.DialTimeout(network, addr, timeout)
	if err != nil {
		return nil, ldap.NewError(ldap.ErrorNetwork, err)
	}
	// net.DialTimeout only bounds the connect: without a deadline, a server
	// that accepts the TCP connection but never answers the handshake would
	// block Handshake forever.
	if err = raw.SetDeadline(time.Now().Add(timeout)); err != nil {
		raw.Close()
		return nil, ldap.NewError(ldap.ErrorNetwork, err)
	}
	c := tls.Client(raw, config)
	if err = c.Handshake(); err != nil {
		// Handshake error, close the established connection before we return an error
		raw.Close()
		return nil, ldap.NewError(ldap.ErrorNetwork, err)
	}
	// Clear the deadline so it does not apply to the LDAP operations that
	// will run on this connection.
	if err = raw.SetDeadline(time.Time{}); err != nil {
		c.Close()
		return nil, ldap.NewError(ldap.ErrorNetwork, err)
	}
	conn := ldap.NewConn(c, true)
	conn.Start()
	return conn, nil
}
// Auth checks a login against this LDAP level.
//
// Special case: when the config has no servers but has validGroups, this
// level only filters the result of the previous one. logins then holds the
// values produced by that level. Auth reports userOk=true, passOk=false with
// logins as attributes when inArray(logins, conf.validGroups) holds.
//
// Otherwise exactly one login is expected. Auth connects over LDAPS (port 636)
// to the first server of conf.servers that accepts a connection. It binds with
// the service account (bindCn/bindPw), then searches baseDN with searchFilter
// for exactly one entry. It reads primaryAttribute and, if set,
// secondaryAttribute from that entry. When validGroups is non-empty, the
// primary values must satisfy inArray(conf.validGroups, primary).
//
// Results:
//   - e is non-nil when the service bind or the search fails, or when the
//     search does not return exactly one entry.
//   - userOk is true when the account exists and passes the attribute checks.
//   - passOk is true only when pass is non-empty and binding as the user
//     succeeds. Any bind error is treated as a wrong password.
//   - attributes are the values handed to the next level: the primary values
//     when validGroups is empty, the secondary values otherwise.
func (conf *ldapConfig) Auth(logins []string, pass string) (e error, userOk, passOk bool, attributes []string) {
	// special case. This configuration is a filter on the previous one
	if len(conf.servers) == 0 && len(conf.validGroups) > 0 && inArray(logins, conf.validGroups) {
		return nil, true, false, logins
	}
	// no server ldap or multiple login should not happen here
	if len(logins) != 1 || len(conf.servers) == 0 {
		return nil, false, false, nil
	}
	login := logins[0]

	// Use the first server that accepts a connection. LDAPS is forced because
	// we can. net.JoinHostPort brackets IPv6 literals, unlike s+":636".
	var (
		l      *ldap.Conn
		server string
	)
	for _, s := range conf.servers {
		c, err := myDialTLS("tcp", net.JoinHostPort(s, "636"), &tls.Config{ServerName: s})
		if err != nil {
			log.Println(err)
			continue
		}
		l, server = c, s
		break
	}
	if l == nil {
		// no server is responding, reject the authentication
		log.Println("can't join any ldap server")
		return nil, false, false, logins
	}
	// Deferred once, outside the loop, for the single connection in use.
	defer l.Close()

	// First bind with a read only user
	if err := l.Bind(conf.bindCn, conf.bindPw); err != nil {
		log.Println(err)
		return err, false, false, nil
	}

	requested := []string{"dn", conf.primaryAttribute}
	if conf.secondaryAttribute != "" {
		requested = append(requested, conf.secondaryAttribute)
	}
	searchRequest := ldap.NewSearchRequest(
		conf.baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		// The login comes from the client: escape it so characters such as
		// '*', '(' or ')' cannot change the meaning of the filter.
		fmt.Sprintf(conf.searchFilter, escapeLDAPFilter(login)),
		requested,
		nil,
	)
	sr, err := l.Search(searchRequest)
	if err != nil {
		log.Println(err)
		return err, false, false, nil
	}
	if len(sr.Entries) != 1 {
		return errors.New("User does not exist or too many entries returned"), false, false, nil
	}
	entry := sr.Entries[0]

	// Collect the requested attributes. LDAP attribute names are
	// case-insensitive (RFC 4512), and the server may return a different
	// casing than the configured one.
	var primary, secondary []string
	for _, attr := range entry.Attributes {
		if strings.EqualFold(attr.Name, conf.primaryAttribute) {
			primary = attr.Values
		}
		if strings.EqualFold(attr.Name, conf.secondaryAttribute) {
			secondary = attr.Values
		}
	}

	// user must have both primary and secondary attributes
	if len(primary) == 0 {
		log.Printf("User %s has no %s attribute", login, conf.primaryAttribute)
		return nil, false, false, nil
	}
	// NOTE(review): when secondaryAttribute is empty it is never requested,
	// so this check rejects every user for such a config. Confirm that a
	// secondary attribute is always configured.
	if len(secondary) == 0 {
		log.Printf("User %s has no %s attribute", login, conf.secondaryAttribute)
		return nil, false, false, nil
	}
	// a valid account must be part of the correct group (per instance)
	if len(conf.validGroups) > 0 && !inArray(conf.validGroups, primary) {
		return nil, false, false, nil
	}
	// Without a validGroups check the primary values are passed to the next
	// level. Otherwise the secondary ones are.
	attributes = secondary
	if len(conf.validGroups) == 0 {
		attributes = primary
	}
	log.Printf("User %s has a valid account on %s", login, server)

	// An empty password would make the next Bind an unauthenticated bind
	// (RFC 4513), which servers may accept: stop before checking it.
	if pass == "" {
		return nil, true, false, attributes
	}
	// if there is an error, it's because the password is invalid
	if err := l.Bind(entry.DN, pass); err != nil {
		return nil, true, false, attributes
	}
	log.Printf("User %s has a valid password on %s", login, server)
	return nil, true, true, attributes
}

// escapeLDAPFilter escapes s for use as a value inside an LDAP search filter
// (RFC 4515 section 3). '*', '(', ')', '\' and NUL are each replaced by a
// backslash followed by their two-digit hexadecimal code. All other bytes are
// copied unchanged.
func escapeLDAPFilter(s string) string {
	const hexDigits = "0123456789abcdef"
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '*', '(', ')', '\\', 0:
			out = append(out, '\\', hexDigits[c>>4], hexDigits[c&0x0f])
		default:
			out = append(out, c)
		}
	}
	return string(out)
}