Initial euclide.org release
This commit is contained in:
commit
97379c8e8a
|
@ -0,0 +1,7 @@
|
|||
images:
|
||||
- name: pdns-auth-proxy
|
||||
build_args: .
|
||||
repositories:
|
||||
quayio:
|
||||
owners:
|
||||
- tribeinfrastructure
|
|
@ -0,0 +1,41 @@
|
|||
pdns-auth-proxy
|
||||
pdns-proxy.conf
|
||||
.DS_Store
|
||||
localtest.conf
|
||||
.*.swp
|
||||
coverage.out
|
||||
fixtures/replay/commands.asc
|
||||
fixtures/test/badclient-cert.pem
|
||||
fixtures/test/badclient-csr.pem
|
||||
fixtures/test/ca.crt
|
||||
fixtures/test/ca.srl
|
||||
fixtures/test/client-cert.pem
|
||||
fixtures/test/client-csr.pem
|
||||
fixtures/test/client-key.pem
|
||||
fixtures/test/myCA.key
|
||||
fixtures/test/server-cert.pem
|
||||
fixtures/test/server-csr.pem
|
||||
fixtures/test/server-key.pem
|
||||
fixtures/test/ldap-cert.pem
|
||||
fixtures/test/ldap-csr.pem
|
||||
fixtures/test/ldap-key.pem
|
||||
fixtures/test/webclient-cert.pem
|
||||
fixtures/test/webclient-csr.pem
|
||||
fixtures/test/webclient-key.pem
|
||||
fixtures/test/public-key.txt
|
||||
fixtures/test/crlnumber
|
||||
fixtures/test/crlnumber.old
|
||||
fixtures/test/index.txt
|
||||
fixtures/test/index.txt.attr
|
||||
fixtures/test/index.txt.old
|
||||
fixtures/test/root.crl.pem
|
||||
debug
|
||||
client/client
|
||||
debian/pdns-auth-proxy.debhelper.log
|
||||
debian/pdns-auth-proxy.postinst.debhelper
|
||||
debian/pdns-auth-proxy.postrm.debhelper
|
||||
debian/pdns-auth-proxy.prerm.debhelper
|
||||
debian/pdns-auth-proxy.substvars
|
||||
debian/files
|
||||
build-stamp
|
||||
fixtures/pdns/pdns.dump
|
|
@ -0,0 +1,14 @@
|
|||
FROM golang:1.12
|
||||
|
||||
ENV GOPATH $HOME/workspace/gopath
|
||||
ENV GOROOT /usr/local/go
|
||||
ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH
|
||||
RUN mkdir -p $GOPATH
|
||||
RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH"
|
||||
|
||||
RUN go install golang.org/x/lint/golint@latest
|
||||
|
||||
COPY . /src
|
||||
WORKDIR /src
|
||||
|
||||
RUN make vendor
|
|
@ -0,0 +1,10 @@
|
|||
# build stage
|
||||
FROM golang:1.12 AS build-env
|
||||
ADD . /pdns-auth-proxy
|
||||
RUN cd /pdns-auth-proxy && make build
|
||||
|
||||
# final stage
|
||||
FROM alpine
|
||||
WORKDIR /app
|
||||
COPY --from=build-env /pdns-auth-proxy/pdns-auth-proxy /pdns-auth-proxy/
|
||||
ENTRYPOINT /pdns-auth-proxy/pdns-auth-proxy /pdns-auth-proxy/proxy.ini
|
|
@ -0,0 +1,66 @@
|
|||
#!/usr/bin/env groovy
|
||||
@Library('jarvis') _
|
||||
|
||||
pipeline {
|
||||
agent {
|
||||
label 'westeros-agent'
|
||||
}
|
||||
options {
|
||||
buildDiscarder(logRotator(numToKeepStr: '5'))
|
||||
durabilityHint('PERFORMANCE_OPTIMIZED')
|
||||
timestamps()
|
||||
}
|
||||
environment {
|
||||
def alfred = null
|
||||
def repo_name = sh returnStdout: true, script: 'basename -s .git $(git config --get remote.origin.url)'
|
||||
def version = sh returnStdout: true, script: 'git describe --tags'
|
||||
VERSION = "${version}"
|
||||
SLACK_CHANNEL = 'admin-install'
|
||||
}
|
||||
stages {
|
||||
stage("Alfred") {
|
||||
steps {
|
||||
script {
|
||||
alfred = getAlfredConfig()
|
||||
sh 'make -f Makefile.ci init'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Build") {
|
||||
steps {
|
||||
sh 'make -f Makefile.ci build'
|
||||
}
|
||||
}
|
||||
stage("Quality") {
|
||||
steps {
|
||||
parallel(
|
||||
format: {
|
||||
sh 'make -f Makefile.ci fmt'
|
||||
},
|
||||
lint: {
|
||||
sh 'make -f Makefile.ci lint'
|
||||
},
|
||||
test: {
|
||||
sh 'make -f Makefile.ci test'
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
stage('TO DO: build a debian package & publish it in packages.dm.gg ') {
|
||||
when {
|
||||
branch "master"
|
||||
}
|
||||
steps {
|
||||
echo "Expecting something here"
|
||||
}
|
||||
}
|
||||
}
|
||||
post {
|
||||
success {
|
||||
sendPipelineStatusToSlack('SUCCESS', "Gitlab - Docker - ${repo_name}")
|
||||
}
|
||||
failure {
|
||||
sendPipelineStatusToSlack('FAILURE', "Gitlab - Docker - ${repo_name}")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
# -*- mode: makefile -*-
|
||||
SHELL := /bin/bash
|
||||
PREFIX?=$(shell pwd)
|
||||
|
||||
# Setup name variables for the package/tool
|
||||
PROJECTNAME := pdns-auth-proxy
|
||||
|
||||
# Set our default go compiler
|
||||
GO := go
|
||||
|
||||
VERSION := $(shell git describe --tags)
|
||||
GITCOMMIT := $(shell git rev-parse --short HEAD)
|
||||
PKG := git.euclide.org/euclide/$(PROJECTNAME)
|
||||
|
||||
CTIMEVAR=-X $(PKG)/version.GITCOMMIT=$(GITCOMMIT) -X $(PKG)/version.VERSION=$(VERSION)
|
||||
GO_LDFLAGS=-ldflags "-w $(CTIMEVAR)"
|
||||
GO_LDFLAGS_STATIC=-ldflags "-w $(CTIMEVAR) -extldflags -static"
|
||||
GOCACHE=off
|
||||
|
||||
BUILDDIR := ${PREFIX}/build
|
||||
DOCKER_IMAGE := $(docker images --format "{{.Repository}}" --filter=reference='$(PROJECTNAME):$(GITCOMMIT)')
|
||||
GOOSARCHES := linux/amd64
|
||||
|
||||
all: build fmt lint test
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Verifies all files have been `gofmt`ed.
|
||||
@echo "+ $@"
|
||||
$(eval var = $(shell gofmt -s -l . | grep -v vendor | tee /dev/stderr))
|
||||
@if [ ! -z "${var}" ]; then exit 1; fi
|
||||
$(MAKE) -C client fmt
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Verifies `golint` passes.
|
||||
@echo "+ $@"
|
||||
$(eval var = $(shell golint ./... | grep -v vendor | tee /dev/stderr))
|
||||
@if [ ! -z "${var}" ]; then exit 1; fi
|
||||
$(MAKE) -C client lint
|
||||
|
||||
.PHONY: tag
|
||||
tag: ## Create a new git tag to prepare to build a release
|
||||
git tag -sa $(VERSION) -m "$(VERSION)"
|
||||
@echo "Run git push origin $(VERSION) to push your new tag to GitHub and trigger a Jenkins build."
|
||||
|
||||
.PHONY: test
|
||||
test: ## Generates test certificates & run unit tests.
|
||||
$(MAKE) -C fixtures/test
|
||||
go vet -mod=vendor
|
||||
go test -v -mod=vendor -coverprofile=coverage.out
|
||||
$(MAKE) -C client test
|
||||
|
||||
.PHONY: resetreplay
|
||||
resetreplay:
|
||||
rm -f fixtures/replay/record
|
||||
$(MAKE) clean
|
||||
$(MAKE) test
|
||||
|
||||
.PHONY: deb
|
||||
deb:
|
||||
debuild -e GOROOT -e PATH -i -us -uc -b
|
||||
|
||||
.PHONY: build
|
||||
build: ## Builds a static executable.
|
||||
@echo "+ $@"
|
||||
CGO_ENABLED=0 $(GO) build -mod=vendor \
|
||||
-o $(PROJECTNAME) \
|
||||
-tags "static_build netgo" \
|
||||
-trimpath \
|
||||
-installsuffix netgo ${GO_LDFLAGS_STATIC} .;
|
||||
strip $(PROJECTNAME) 2>/dev/null || echo
|
||||
$(MAKE) -C client build
|
||||
|
||||
.PHONY: vendor
|
||||
vendor: ## Updates the vendoring directory.
|
||||
# @$(RM) go.sum
|
||||
@$(RM) -r vendor
|
||||
GOPRIVATE=git.euclide.org $(GO) mod init $(PKG) || true
|
||||
GOPRIVATE=git.euclide.org $(GO) mod tidy
|
||||
GOPRIVATE=git.euclide.org $(GO) mod vendor
|
||||
@$(RM) Gopkg.toml Gopkg.lock
|
||||
|
||||
.PHONY: clean
|
||||
clean: ## Cleanup any build binaries or packages.
|
||||
@echo "+ $@"
|
||||
$(RM) $(PROJECTNAME) debug
|
||||
$(RM) -r $(BUILDDIR)
|
||||
$(MAKE) -C fixtures/test clean
|
||||
$(MAKE) -C client clean
|
|
@ -0,0 +1,46 @@
|
|||
# -*- mode: makefile -*-
|
||||
SHELL := /bin/sh
|
||||
PROJECTNAME := pdns-auth-proxy
|
||||
VERSION := $(shell git rev-parse --short HEAD)
|
||||
TO := _
|
||||
|
||||
ifdef BUILD_NUMBER
|
||||
NUMBER = $(BUILD_NUMBER)
|
||||
else
|
||||
NUMBER = 1
|
||||
endif
|
||||
|
||||
ifdef JOB_BASE_NAME
|
||||
PROJECT_ENCODED_SLASH = $(subst %2F,$(TO),$(JOB_BASE_NAME))
|
||||
PROJECT = $(subst /,$(TO),$(PROJECT_ENCODED_SLASH))
|
||||
# Run on CI
|
||||
COMPOSE = docker-compose -f docker-compose.yml -f docker-compose.ci.yml -p $(PROJECTNAME)_$(PROJECT)_$(NUMBER)
|
||||
else
|
||||
# Run Locally
|
||||
COMPOSE = docker-compose -p $(PROJECTNAME)
|
||||
endif
|
||||
|
||||
.PHONY: init
|
||||
init:
|
||||
# This following command is used to provision the network
|
||||
$(COMPOSE) up --no-start --no-build app | true
|
||||
|
||||
.PHONY: build
|
||||
build:
|
||||
$(COMPOSE) run build
|
||||
|
||||
.PHONY: fmt
|
||||
format:
|
||||
$(COMPOSE) run fmt
|
||||
|
||||
.PHONY: lint
|
||||
lint:
|
||||
$(COMPOSE) run lint
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
$(COMPOSE) run test
|
||||
|
||||
.PHONY: down
|
||||
down:
|
||||
$(COMPOSE) down --volumes
|
|
@ -0,0 +1,242 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pyke369/golang-support/rcache"
|
||||
"golang.org/x/crypto/openpgp"
|
||||
)
|
||||
|
||||
// AuthProfile interface
|
||||
type AuthProfile interface {
|
||||
Match(string) (bool, error)
|
||||
PgpKeys() openpgp.EntityList
|
||||
}
|
||||
|
||||
// AuthProfileRegexp implements auth based on regexp only
|
||||
type AuthProfileRegexp struct {
|
||||
subjectRegexp *regexp.Regexp
|
||||
pgpKeys openpgp.EntityList
|
||||
}
|
||||
|
||||
// AuthProfileLdap implements auth based on regexp then ldap check
|
||||
type AuthProfileLdap struct {
|
||||
subjectRegexp *regexp.Regexp
|
||||
ldap *LdapHandler
|
||||
}
|
||||
|
||||
// JSONRPCACL represents the acls for the json RPC API
|
||||
type JSONRPCACL struct {
|
||||
actions map[string][]*regexp.Regexp
|
||||
pgpProfiles []string
|
||||
sslProfiles []string
|
||||
}
|
||||
|
||||
// NewJSONRPCACL instanciates json RPC API ACLs.
|
||||
func NewJSONRPCACL(regexps map[string][]string, pgpProfiles, sslProfiles []string) *JSONRPCACL {
|
||||
realPerms := map[string][]*regexp.Regexp{}
|
||||
for action, reg := range regexps {
|
||||
realPerms[action] = []*regexp.Regexp{}
|
||||
for _, r := range reg {
|
||||
realPerms[action] = append(realPerms[action], rcache.Get(fmt.Sprintf("^%s$", r)))
|
||||
}
|
||||
}
|
||||
a := JSONRPCACL{
|
||||
actions: realPerms,
|
||||
pgpProfiles: pgpProfiles,
|
||||
sslProfiles: sslProfiles,
|
||||
}
|
||||
return &a
|
||||
}
|
||||
|
||||
// GetListFilters returns the restrictions on the list action (or * if // defined)
|
||||
func (a JSONRPCACL) GetListFilters(pgpProfiles, sslProfiles []string, method string) []*regexp.Regexp {
|
||||
var ret []*regexp.Regexp
|
||||
// ignore acl for which the profile doesn't match
|
||||
if !inArray(sslProfiles, a.sslProfiles) && !inArray(pgpProfiles, a.pgpProfiles) {
|
||||
return nil
|
||||
}
|
||||
for aclAction, targets := range a.actions {
|
||||
if aclAction == "*" || aclAction == method {
|
||||
ret = append(ret, targets...)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Match method for a json RPC ACL
|
||||
func (a JSONRPCACL) Match(action, target string, pgpProfiles, sslProfiles []string) bool {
|
||||
// ignore acl for which the profile doesn't match
|
||||
if !inArray(sslProfiles, a.sslProfiles) && !inArray(pgpProfiles, a.pgpProfiles) {
|
||||
return false
|
||||
}
|
||||
// the profile match somehow, let's test the action
|
||||
for aclAction, targets := range a.actions {
|
||||
// action must match
|
||||
if aclAction != action && aclAction != "*" {
|
||||
continue
|
||||
}
|
||||
// and target must match too
|
||||
for _, reg := range targets {
|
||||
if reg.MatchString(target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PdnsACL represents the acls for direct access to the PDNS API
|
||||
type PdnsACL struct {
|
||||
regexp *regexp.Regexp
|
||||
read bool // Allow GET method
|
||||
write bool // Allow every other method
|
||||
profiles []string
|
||||
}
|
||||
|
||||
// NewPdnsACL instanciates ACLs.
|
||||
func NewPdnsACL(regexp string, perms []string, profiles []string) *PdnsACL {
|
||||
a := PdnsACL{
|
||||
regexp: rcache.Get(fmt.Sprintf("^%s$", regexp)),
|
||||
profiles: profiles,
|
||||
read: false,
|
||||
write: false,
|
||||
}
|
||||
for _, v := range perms {
|
||||
switch v {
|
||||
case "r":
|
||||
a.read = true
|
||||
case "w":
|
||||
a.write = true
|
||||
}
|
||||
}
|
||||
return &a
|
||||
}
|
||||
|
||||
// Match method for an ACL
|
||||
func (a PdnsACL) Match(path, method string, profiles []string) bool {
|
||||
// ignore acl for which the profile doesn't match
|
||||
if !inArray(profiles, a.profiles) {
|
||||
return false
|
||||
}
|
||||
// check the method is valid
|
||||
if (method == "GET" && !a.read) || (method != "GET" && !a.write) {
|
||||
return false
|
||||
}
|
||||
return a.regexp.MatchString(path)
|
||||
}
|
||||
|
||||
// NewAuthProfileRegexp instanciates a regexp based profile
|
||||
func NewAuthProfileRegexp(subjectRegexp, pgpKeys string) AuthProfileRegexp {
|
||||
p := AuthProfileRegexp{
|
||||
subjectRegexp: rcache.Get(fmt.Sprintf("^%s$", subjectRegexp)),
|
||||
}
|
||||
keyring, err := openpgp.ReadArmoredKeyRing(strings.NewReader(pgpKeys))
|
||||
if err == nil {
|
||||
p.pgpKeys = keyring
|
||||
} else {
|
||||
p.pgpKeys = openpgp.EntityList{}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Match method for Regexp based profile
|
||||
func (p AuthProfileRegexp) Match(subject string) (bool, error) {
|
||||
return p.subjectRegexp.MatchString(subject), nil
|
||||
}
|
||||
|
||||
// PgpKeys Method don't apply for a Regexp based profile
|
||||
func (p AuthProfileRegexp) PgpKeys() openpgp.EntityList {
|
||||
return p.pgpKeys
|
||||
}
|
||||
|
||||
// NewAuthProfileLdap instanciates ldap based profile
|
||||
func NewAuthProfileLdap(subjectRegexp string, servers []string,
|
||||
bindCn, bindPw, baseDN, filter, attr, pgpAttr string, valid []string, ssl bool) (AuthProfileLdap, error) {
|
||||
p := AuthProfileLdap{
|
||||
subjectRegexp: rcache.Get(fmt.Sprintf("^%s$", subjectRegexp)),
|
||||
ldap: NewLdap(servers, bindCn, bindPw, baseDN, filter, attr, pgpAttr, valid, 10, ssl),
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Match Method for ldap based profile
|
||||
func (p AuthProfileLdap) Match(subject string) (bool, error) {
|
||||
if !p.subjectRegexp.MatchString(subject) {
|
||||
return false, nil
|
||||
}
|
||||
return p.ldap.Auth(subject)
|
||||
}
|
||||
|
||||
// PgpKeys Method for ldap based profile
|
||||
func (p AuthProfileLdap) PgpKeys() openpgp.EntityList {
|
||||
var kr openpgp.EntityList
|
||||
ret, err := p.ldap.PgpKeys()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil
|
||||
}
|
||||
if ret == nil {
|
||||
return kr
|
||||
}
|
||||
for _, key := range ret {
|
||||
k, err := openpgp.ReadArmoredKeyRing(strings.NewReader(key))
|
||||
if err != nil {
|
||||
log.Println("Error reading Armored Key: ", err)
|
||||
continue
|
||||
}
|
||||
kr = append(kr, k...)
|
||||
}
|
||||
return kr
|
||||
}
|
||||
|
||||
// NewSalt generate a nonce
|
||||
func NewSalt(size int) string {
|
||||
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
b := make([]rune, size)
|
||||
for i := range b {
|
||||
b[i] = letterRunes[rand.Intn(len(letterRunes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ComputeHmac256 return a HMAC, base64 encoded
|
||||
func ComputeHmac256(message string, secret string) string {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(message))
|
||||
return strings.TrimRight(base64.StdEncoding.EncodeToString(h.Sum(nil)), "=")
|
||||
}
|
||||
|
||||
// PgpMessageVerify checks a message signature
|
||||
func PgpMessageVerify(msg []byte, sig []byte, keyring openpgp.EntityList) string {
|
||||
// no keyring, no need to try to decode
|
||||
if len(keyring) == 0 {
|
||||
return ""
|
||||
}
|
||||
signature := bytes.NewReader(sig)
|
||||
message := bytes.NewReader(msg)
|
||||
|
||||
entity, err := openpgp.CheckDetachedSignature(keyring, message, signature)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return ""
|
||||
}
|
||||
for _, a := range entity.Identities {
|
||||
e, err := mail.ParseAddress(a.Name)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return a.Name
|
||||
}
|
||||
return strings.Split(e.Address, "@")[0]
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuthCache structure, to avoid hitting to hard on the ldap servers
|
||||
type AuthCache struct {
|
||||
m sync.RWMutex
|
||||
users map[string]time.Time
|
||||
gpgKeys []string
|
||||
gpgTime time.Time
|
||||
}
|
||||
|
||||
// lock the structure
|
||||
func (a *AuthCache) lock() {
|
||||
a.m.Lock()
|
||||
}
|
||||
|
||||
// or unlock it
|
||||
func (a *AuthCache) unlock() {
|
||||
a.m.Unlock()
|
||||
}
|
||||
|
||||
// NewAuthCache initializes the cache
|
||||
func NewAuthCache() *AuthCache {
|
||||
return &AuthCache{
|
||||
users: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Get checks if a user is cached
|
||||
// and reinitialize counter to add another minute
|
||||
func (a *AuthCache) Get(user string) bool {
|
||||
a.lock()
|
||||
defer a.unlock()
|
||||
now := time.Now()
|
||||
if added, ok := a.users[user]; ok {
|
||||
if added.Add(5 * time.Minute).After(now) {
|
||||
a.users[user] = now
|
||||
return true
|
||||
}
|
||||
delete(a.users, user)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Set marks the user valid for a minute
|
||||
func (a *AuthCache) Set(user string) {
|
||||
a.lock()
|
||||
defer a.unlock()
|
||||
a.users[user] = time.Now()
|
||||
}
|
||||
|
||||
// PgpGet the pgp Keys if cached
|
||||
// and reinitialize the counter to add another minute
|
||||
func (a *AuthCache) PgpGet() []string {
|
||||
now := time.Now()
|
||||
if len(a.gpgKeys) > 0 && a.gpgTime.Add(5*time.Minute).After(now) {
|
||||
a.gpgTime = now
|
||||
return a.gpgKeys
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// PgpSet store the pgp Keys for a minute
|
||||
func (a *AuthCache) PgpSet(keys []string) {
|
||||
a.lock()
|
||||
defer a.unlock()
|
||||
a.gpgKeys = keys
|
||||
a.gpgTime = time.Now()
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/pyke369/golang-support/uconfig"
|
||||
)
|
||||
|
||||
var fakeserver *HTTPServer
|
||||
|
||||
func init() {
|
||||
config, err := uconfig.New("pdns-proxy-test.conf")
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fakeserver = NewHTTPServer(
|
||||
"127.0.0.01:8080",
|
||||
"fixtures/test/server-key.pem",
|
||||
"fixtures/test/server-cert.pem",
|
||||
"",
|
||||
"fixtures/test/ca.crt",
|
||||
"",
|
||||
"",
|
||||
3,
|
||||
3600,
|
||||
)
|
||||
populateHTTPServerAcls(config, fakeserver)
|
||||
populateHTTPServerProfiles(config, fakeserver)
|
||||
|
||||
go fakeserver.Run()
|
||||
}
|
||||
|
||||
// Testing Functions
|
||||
|
||||
func TestAuthRegexpProfiles(t *testing.T) {
|
||||
v := fakeserver.getProfiles("validserver")
|
||||
if len(v) != 1 || v[0] != "testS2S" {
|
||||
t.Errorf("cannot valid test profile")
|
||||
}
|
||||
v = fakeserver.getProfiles("bidule.example.com")
|
||||
if len(v) != 0 {
|
||||
t.Errorf("too many validations")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
for k, testCase := range []struct {
|
||||
path, user, method string
|
||||
expected bool
|
||||
message string
|
||||
}{
|
||||
{"zones/specificdomain.example", "validserver", "POST", true, "valid user and path"},
|
||||
{"zones/dev.example.com", "validserver", "GET", true, "valid user and path"},
|
||||
{"zones/developper.example.com", "validserver", "GET", false, "valid user and invalid path"},
|
||||
{"zones/otherdomain.example", "validserver", "GET", false, "valid user and invalid path"},
|
||||
{"zones/dev.example.com", "validserver", "POST", false, "valid user, valid path, invalid method"},
|
||||
{"zones/specificdomain.example", "invalidserver", "GET", false, "invalid user"},
|
||||
} {
|
||||
if fakeserver.nativeValidAuth(testCase.path, testCase.user, testCase.method) != testCase.expected {
|
||||
result := "worked"
|
||||
if testCase.expected {
|
||||
result = "failed"
|
||||
}
|
||||
t.Errorf("%s but it %s on URL %d", testCase.message, result, k+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInArray(t *testing.T) {
|
||||
for k, testCase := range []struct {
|
||||
search []string
|
||||
list []string
|
||||
expected bool
|
||||
message string
|
||||
}{
|
||||
{[]string{"apple", "orange"}, []string{"kiwi", "apple"}, true, "apple should have matched"},
|
||||
{[]string{"apple", "orange"}, []string{"kiwi", "apple"}, true, "apple should have matched"},
|
||||
{[]string{"citrus", "orange"}, []string{"kiwi", "apple"}, false, "nothing should have matched"},
|
||||
} {
|
||||
if inArray(testCase.search, testCase.list) != testCase.expected {
|
||||
result := "worked"
|
||||
if testCase.expected {
|
||||
result = "failed"
|
||||
}
|
||||
t.Errorf("%s but it %s on test %d", testCase.message, result, k+1)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
# -*- mode: makefile -*-
|
||||
|
||||
PROJECTNAME := client
|
||||
GO_LDFLAGS_STATIC=-ldflags "-w $(CTIMEVAR) -extldflags -static"
|
||||
GO := go
|
||||
|
||||
all: build fmt lint test
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Verifies all files have been `gofmt`ed.
|
||||
@echo "+ $@"
|
||||
$(eval var = $(shell gofmt -s -l . | grep -v vendor | tee /dev/stderr))
|
||||
@if [ ! -z "${var}" ]; then exit 1; fi
|
||||
|
||||
|
||||
.PHONY: build
|
||||
build: ## Builds a static executable.
|
||||
@echo "+ $@"
|
||||
CGO_ENABLED=0 $(GO) build -mod=vendor \
|
||||
-o $(PROJECTNAME) \
|
||||
-tags "static_build netgo" \
|
||||
-installsuffix netgo ${GO_LDFLAGS_STATIC} .;
|
||||
|
||||
|
||||
.PHONY: test
|
||||
test: ## Generates test certificates & run unit tests.
|
||||
go vet -mod=vendor
|
||||
go test -v -mod=vendor
|
||||
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Verifies `golint` passes.
|
||||
@echo "+ $@"
|
||||
$(eval var = $(shell golint ./... | grep -v vendor | tee /dev/stderr))
|
||||
@if [ ! -z "${var}" ]; then exit 1; fi
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -f $(PROJECTNAME)
|
|
@ -0,0 +1,719 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// the default API url
|
||||
defaultAPI = "https://localhost:8443"
|
||||
)
|
||||
|
||||
type (
|
||||
// json definitions for the web API
|
||||
jsonArray []*jsonCommand
|
||||
jsonCommand struct {
|
||||
Method string `json:"method"`
|
||||
ID int `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Params jsonCommandParams `json:"params"`
|
||||
args []string
|
||||
}
|
||||
jsonCommandParams struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
TTL int `json:"ttl"`
|
||||
ForceReverse bool `json:"reverse"`
|
||||
Append bool `json:"append"`
|
||||
Comment string `json:"comment"`
|
||||
DryRun bool `json:"dry-run"`
|
||||
IgnoreError bool `json:"ignore-error"`
|
||||
Nonce string `json:"nonce"`
|
||||
Debug bool `json:"-"`
|
||||
}
|
||||
apiResponse struct {
|
||||
ID int `json:"id"`
|
||||
Result []struct {
|
||||
Changes string `json:"changes"`
|
||||
Comment string `json:"comment"`
|
||||
Result string `json:"result"`
|
||||
} `json:"result"`
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
// structure for command line options
|
||||
command struct {
|
||||
args []string
|
||||
descr string
|
||||
options []string
|
||||
check func(*jsonCommand) bool
|
||||
}
|
||||
)
|
||||
|
||||
// global vars
|
||||
var (
|
||||
dnsAPI string
|
||||
httpClient http.Client
|
||||
commands map[string]command
|
||||
generalFlags []string
|
||||
)
|
||||
|
||||
// Populate the generalFlags, commands and dnsAPI variables
|
||||
func init() {
|
||||
generalFlags = []string{"dry-run", "comment"}
|
||||
commands = map[string]command{
|
||||
"a": {
|
||||
[]string{"<record>", "<ipv4>"},
|
||||
"Create a A record, points to an IPv4. If the IP has no reverse, create it with the record value (won't apply for wildcard)",
|
||||
[]string{"append", "reverse", "ttl"}, cmdCheckIPv4},
|
||||
"aaaa": {
|
||||
[]string{"<record>", "<ipv6>"},
|
||||
"Create a AAAA record, points to an IPv6. If the IP has no reverse, create it with the record value (won't apply for wildcard)",
|
||||
[]string{"append", "reverse", "ttl"}, cmdCheckIPv6},
|
||||
"cname": {
|
||||
[]string{"<record>", "<destination>"},
|
||||
"Create a CNAME record, points to another name. If that name is managed by the DNS server, it must exist",
|
||||
[]string{"ttl"}, nil},
|
||||
"dname": {
|
||||
[]string{"<record>", "<destination>"},
|
||||
"Create a DNAME record, points to another name. If that name is managed by the DNS server, it must exist",
|
||||
[]string{"ttl"}, nil},
|
||||
"caa": {
|
||||
[]string{"<domain>", "<flag>", "<tag>", "<value>"},
|
||||
"Create a CAA record. If value contains spaces, it must be quoted",
|
||||
[]string{"append", "ttl"}, cmdCheckCAA},
|
||||
"srv": {
|
||||
[]string{"<_service._proto.name>", "<priority>", "<weight>", "<port>", "<target>"},
|
||||
"Create a SRV record. https://en.wikipedia.org/wiki/SRV_record",
|
||||
[]string{"append", "ttl"}, cmdCheckSRV},
|
||||
"txt": {
|
||||
[]string{"<record>", "<\"text\">"},
|
||||
"Create a TXT record. No need to escape quotes, it will be done automatically on the server",
|
||||
[]string{"append", "ttl"}, nil},
|
||||
"mx": {
|
||||
[]string{"<record>", "<priority>", "<mail-server>"},
|
||||
"Create a MX record. mail-server must exist",
|
||||
[]string{"append", "ttl"}, cmdCheckMX},
|
||||
"ns": {
|
||||
[]string{"<record>", "<dns-server>"},
|
||||
"Create a NS record. dns-server must exist and must be an external server",
|
||||
[]string{"append", "ttl"}, nil},
|
||||
"ptr": {
|
||||
[]string{"<ip | something.arpa>", "<name>"},
|
||||
"Add a PTR record. If the first argument is an IP, will convert it to an .arpa name",
|
||||
[]string{"append", "ttl"}, cmdCheckPTR},
|
||||
"delete": {
|
||||
[]string{"<record>", "[value]"},
|
||||
"Remove the record. If value is not provided, remove all records of all types sharing the name. If value is specified, only remove that particuliar value. If a deleted record corresponds to a matching reverse, will remove the reverse too",
|
||||
[]string{}, nil},
|
||||
"ttl": {
|
||||
[]string{"<record>", "[value]", "<ttl>"},
|
||||
"Change the TTL of a record. If value is not provided, change the TTL of all types sharing the name. If value is specified, only change that particuliar TTL",
|
||||
[]string{}, cmdCheckTTL},
|
||||
"newzone": {
|
||||
[]string{"<zone>"},
|
||||
"Add a new zone. The zone will be private for a private IPs reverse zone. NS and SOA options are automatically changed",
|
||||
[]string{}, nil},
|
||||
"search": {
|
||||
[]string{"<query>"},
|
||||
"Search the pdns database",
|
||||
[]string{}, nil},
|
||||
"dump": {
|
||||
[]string{"<zone>"},
|
||||
"Display the zone",
|
||||
[]string{}, nil},
|
||||
"list": {
|
||||
[]string{"[regexp]"},
|
||||
"List zones, filter on regexp if provided",
|
||||
[]string{}, nil},
|
||||
"batch": {
|
||||
[]string{"<file>"},
|
||||
"Batch mode: file is a jsonRPC 2.0 complient json file with all commands you want to execute. Example : \n [\n { \"jsonrpc\": \"2.0\", \"id\": 0, \"method\": \"list\" },\n { \"jsonrpc\": \"2.0\", \"id\": 1, \"method\": \"newzone\", \"params\": { \"name\": \"example.com\", \"ignore-error\": true } },\n { \"jsonrpc\": \"2.0\", \"id\": 2, \"method\": \"a\", \"params\": {\n \"comment\": \"it's the fault\", \"name\": \"toto.example.com\",\n \"value\": \"192.0.2.1\" } }\n ]\n By default, an error will stop the batch. Use the boolean ignore-error to change the behaviour for a particular line. The comment and ttl switches are applied only if no explicit value is provided in a line",
|
||||
[]string{"ttl"}, nil},
|
||||
}
|
||||
// we can override the default API url (for dev/debug). In this case,
|
||||
// ignore the SSL
|
||||
dnsAPI = defaultAPI
|
||||
if os.Getenv("DNS_API") != "" {
|
||||
dnsAPI = os.Getenv("DNS_API")
|
||||
}
|
||||
// ignore security for localhost
|
||||
if strings.HasPrefix(dnsAPI, "https://localhost") {
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
// set the http timeout
|
||||
httpClient.Timeout = time.Duration(60 * time.Second)
|
||||
|
||||
flag.Usage = Usage
|
||||
}
|
||||
|
||||
// cmdCheckIPv4 is the validation function for the "a" action.
|
||||
// Checks the 2nd argument is an IPv4 address
|
||||
func cmdCheckIPv4(j *jsonCommand) bool {
|
||||
ip := net.ParseIP(j.args[2])
|
||||
return ip != nil && ip.To4() != nil
|
||||
}
|
||||
|
||||
// cmdCheckIPv6 is the validation function for the "aaaa" action.
|
||||
// Checks the 2nd argument is an IPv6 address
|
||||
func cmdCheckIPv6(j *jsonCommand) bool {
|
||||
ip := net.ParseIP(j.args[2])
|
||||
return ip != nil && ip.To4() == nil
|
||||
}
|
||||
|
||||
// cmdCheckPTR is the validation function for the "reverse" action.
|
||||
// Checks the 1st argument is an IP or a valid .arpa name,
|
||||
// converts the first argument to the valid .arpa name if necessary,
|
||||
// checks that the second argument points back to the first
|
||||
func cmdCheckPTR(j *jsonCommand) bool {
|
||||
var ip net.IP
|
||||
|
||||
j.args[1] = strings.Trim(j.args[1], ".")
|
||||
j.args[2] = strings.Trim(j.args[2], ".")
|
||||
if strings.HasSuffix(j.args[1], ".arpa") {
|
||||
ip = net.ParseIP(ptrToIP(j.args[1]))
|
||||
} else {
|
||||
ip = net.ParseIP(j.args[1])
|
||||
}
|
||||
// if ip is not valid, stop here
|
||||
if ip == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s is not a valid IP or reverse\n\n", j.args[1])
|
||||
return false
|
||||
}
|
||||
j.Params.Name = iPtoReverse(ip)
|
||||
strIP := ip.String()
|
||||
|
||||
names, err := net.LookupHost(j.args[2])
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: %s must exist\n\n", j.args[2])
|
||||
return true
|
||||
}
|
||||
cname, err := net.LookupCNAME(j.args[2])
|
||||
cname = strings.Trim(cname, ".")
|
||||
if err == nil && cname != j.args[2] {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s cannot be a CNAME\n\n", j.args[2])
|
||||
return false
|
||||
}
|
||||
for _, n := range names {
|
||||
if n == strIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Warning: %s must point to %s\n\n", j.args[2], strIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// cmdCheckCAA is the validation function for the "caa" action.
|
||||
func cmdCheckCAA(j *jsonCommand) bool {
|
||||
const q = "\""
|
||||
var err error
|
||||
var tag int
|
||||
|
||||
if tag, err = strconv.Atoi(j.args[2]); err != nil || tag < 0 || tag > 255 {
|
||||
return false
|
||||
}
|
||||
j.args[4] = strings.Trim(j.args[4], q)
|
||||
j.args[4] = strings.Replace(j.args[4], q, "\\\"", -1)
|
||||
j.args[4] = q + j.args[4] + q
|
||||
|
||||
j.Params.Value = fmt.Sprintf("%d %s %s", tag, j.args[3], j.args[4])
|
||||
return true
|
||||
}
|
||||
|
||||
// cmdCheckSRV is the validation function for the "srv" action.
|
||||
// Checks that the arguments are valid, and put them all in arg[2]
|
||||
func cmdCheckSRV(j *jsonCommand) bool {
|
||||
var err error
|
||||
var prio, weight, port int
|
||||
validSRV := regexp.MustCompile("^_[a-z0-9]+\\._(tcp|udp|tls)[^ ]+$")
|
||||
if !validSRV.MatchString(j.args[1]) {
|
||||
return false
|
||||
}
|
||||
if prio, err = strconv.Atoi(j.args[2]); err != nil || prio < 0 {
|
||||
return false
|
||||
}
|
||||
if weight, err = strconv.Atoi(j.args[3]); err != nil || weight < 0 {
|
||||
return false
|
||||
}
|
||||
if port, err = strconv.Atoi(j.args[4]); err != nil || port < 1 {
|
||||
return false
|
||||
}
|
||||
j.Params.Value = fmt.Sprintf("%d %d %d %s", prio, weight, port, j.args[5])
|
||||
return true
|
||||
}
|
||||
|
||||
// cmdCheckMX is the validation function for the "mx" action.
|
||||
// Checks that the arguments are a weight and a mail servers
|
||||
func cmdCheckMX(j *jsonCommand) bool {
|
||||
var err error
|
||||
var prio int
|
||||
|
||||
if prio, err = strconv.Atoi(j.args[2]); err != nil || prio < 0 {
|
||||
return false
|
||||
}
|
||||
j.Params.Value = fmt.Sprintf("%d %s", prio, j.args[3])
|
||||
return true
|
||||
}
|
||||
|
||||
// cmdCheckTTL is the validation function for the "ttl" action.
|
||||
func cmdCheckTTL(j *jsonCommand) bool {
|
||||
var err error
|
||||
|
||||
// The args for this command are "<record>", "[value]", "<ttl>" (the
|
||||
// middle argument is optional).
|
||||
n := len(j.args)
|
||||
if j.Params.TTL, err = strconv.Atoi(j.args[n-1]); err != nil || j.Params.TTL < 0 {
|
||||
return false
|
||||
}
|
||||
// Default value if the optional argument is not given
|
||||
if n < 4 {
|
||||
j.Params.Value = ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Output the usage of this specific command
|
||||
func (c *command) Usage(name string) string {
|
||||
ret := fmt.Sprintf(" %s [options] %s %s : \n %s\n Options:\n",
|
||||
os.Args[0], name, c.formatArgs(), c.formatDescr())
|
||||
|
||||
findSame := map[string]string{}
|
||||
alt := map[string]string{}
|
||||
|
||||
for _, opt := range append(generalFlags, c.options...) {
|
||||
fl := flag.Lookup(opt)
|
||||
findSame[fl.Usage] = opt
|
||||
}
|
||||
flag.VisitAll(func(f *flag.Flag) {
|
||||
if opt, ok := findSame[f.Usage]; ok && opt != f.Name {
|
||||
alt[opt] = f.Name
|
||||
}
|
||||
})
|
||||
for _, opt := range append(generalFlags, c.options...) {
|
||||
var short, arg, defValue string
|
||||
f := flag.Lookup(opt)
|
||||
|
||||
if s, ok := alt[opt]; ok {
|
||||
short = fmt.Sprintf("-%s,", s)
|
||||
}
|
||||
if f.DefValue != "false" && f.DefValue != "true" {
|
||||
arg = "<value>"
|
||||
defValue = fmt.Sprintf(", default is %s", f.DefValue)
|
||||
}
|
||||
if f.DefValue == "" {
|
||||
defValue += "\"\""
|
||||
}
|
||||
ret += fmt.Sprintf(" %-3v --%-7v %-7v : %s%s\n",
|
||||
short, f.Name, arg, f.Usage, defValue)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// CheckArgs populate the json structure with the given arguments and check
|
||||
// they are valid
|
||||
func (c *command) CheckArgs(j *jsonCommand) bool {
|
||||
args := flag.Args()
|
||||
min := 1
|
||||
for _, a := range c.args {
|
||||
if a[0] == '<' {
|
||||
min++
|
||||
}
|
||||
}
|
||||
if len(args) > len(c.args)+1 {
|
||||
return false
|
||||
}
|
||||
if len(args) < min {
|
||||
return false
|
||||
}
|
||||
// per construction, there is always at least 2 arguments and the commands
|
||||
// in args
|
||||
if len(args) > 1 {
|
||||
j.Params.Name = args[1]
|
||||
}
|
||||
if len(args) > 2 {
|
||||
j.Params.Value = args[2]
|
||||
}
|
||||
// use the adapted check function. It can modify the structure
|
||||
if c.check != nil {
|
||||
return c.check(j)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SetDryRun set the dry-run flag on every entry
|
||||
func (ja jsonArray) SetDryRun() {
|
||||
for i := range ja {
|
||||
ja[i].Params.DryRun = true
|
||||
}
|
||||
}
|
||||
|
||||
// SetNonce gets a nonce from the API, and stores it in the structure.
|
||||
// The nonce is valid for 10 min, to avoid replay attacks
|
||||
func (ja jsonArray) SetNonce() error {
|
||||
var valid = regexp.MustCompile(`^[A-Za-z0-9+/]*$`)
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/nonce", dnsAPI), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
|
||||
if valid.MatchString(string(body)) {
|
||||
ja[0].Params.Nonce = string(body)
|
||||
return nil
|
||||
}
|
||||
return errors.New("Cannot get Nonce : Invalid response")
|
||||
}
|
||||
|
||||
// SetArgs copy the argument from the command line into the structure
|
||||
func (ja *jsonArray) SetArgs(j *jsonCommand, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("Not enough args")
|
||||
}
|
||||
cmd, ok := commands[args[0]]
|
||||
if !ok {
|
||||
return fmt.Errorf("Unknown command %s", args[0])
|
||||
}
|
||||
j.Method = args[0]
|
||||
j.args = args
|
||||
j.JSONRPC = "2.0"
|
||||
j.ID = 1
|
||||
|
||||
// check the arguments
|
||||
if !cmd.CheckArgs(j) {
|
||||
Usage()
|
||||
}
|
||||
// evacuate the simple case first
|
||||
if args[0] != "batch" {
|
||||
*ja = append(*ja, j)
|
||||
return nil
|
||||
}
|
||||
batch, err := ioutil.ReadFile(args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal(batch, ja); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range *ja {
|
||||
(*ja)[i].JSONRPC = j.JSONRPC
|
||||
(*ja)[i].ID = i
|
||||
if (*ja)[i].Params.TTL == 0 {
|
||||
(*ja)[i].Params.TTL = j.Params.TTL
|
||||
}
|
||||
if (*ja)[i].Params.Comment == "" {
|
||||
(*ja)[i].Params.Comment = j.Params.Comment
|
||||
}
|
||||
if (*ja)[i].Params.DryRun == false {
|
||||
(*ja)[i].Params.DryRun = j.Params.DryRun
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Usage print the usage for all actions, or just the one used
|
||||
func Usage() {
|
||||
fmt.Fprintln(flag.CommandLine.Output(), "Usage: ")
|
||||
if flag.NArg() >= 1 {
|
||||
name := flag.Arg(0)
|
||||
if cmd, ok := commands[name]; ok {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "%s\n\n", cmd.Usage(name))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
s := []string{}
|
||||
for name := range commands {
|
||||
s = append(s, name)
|
||||
}
|
||||
sort.Strings(s)
|
||||
|
||||
for _, name := range s {
|
||||
cmd := commands[name]
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "%s\n\n", cmd.Usage(name))
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// GetYubikey try to find a PGP Smartcard and return its ID
|
||||
func GetYubikey() (string, error) {
|
||||
out, err := exec.Command("gpg", "--card-status", "--with-colons").CombinedOutput()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if strings.HasPrefix(line, "fpr:") {
|
||||
return strings.Split(line, ":")[1], nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("Yubikey issue (is it plugged in?)")
|
||||
}
|
||||
|
||||
// Sign runs the "gpg --clear-sign" command on the input.
|
||||
// If the key ID is empty, it will return the input address
|
||||
func Sign(payload []byte, gpgKey string) ([]byte, error) {
|
||||
if gpgKey == "" {
|
||||
return payload, nil
|
||||
}
|
||||
cmd := exec.Command("gpg", "--clearsign", "-u", gpgKey)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return payload, err
|
||||
}
|
||||
stdin.Write(payload)
|
||||
stdin.Close()
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return payload, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// sendQuery sends the payload to the API server
|
||||
func sendQuery(payload []byte) ([]*apiResponse, error) {
|
||||
apiRespSimple := &apiResponse{}
|
||||
apiRespArray := []*apiResponse{}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/jsonrpc", dnsAPI), bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
s, _ := ioutil.ReadAll(resp.Body)
|
||||
if json.Unmarshal(s, &apiRespArray); len(apiRespArray) > 0 {
|
||||
return apiRespArray, nil
|
||||
}
|
||||
if err := json.Unmarshal(s, apiRespSimple); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(apiRespArray, apiRespSimple), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var j jsonCommand
|
||||
var jsonStruct jsonArray
|
||||
|
||||
// parse the command line
|
||||
flag.BoolVar(&j.Params.ForceReverse, "reverse", false, "Force the creation of a reverse")
|
||||
flag.BoolVar(&j.Params.ForceReverse, "r", false, "Force the creation of a reverse")
|
||||
flag.BoolVar(&j.Params.DryRun, "dry-run", false, "Explain what the command would do, but do nothing")
|
||||
flag.BoolVar(&j.Params.DryRun, "n", false, "Explain what the command would do, but do nothing")
|
||||
flag.StringVar(&j.Params.Comment, "comment", "", "Add a comment to the operation")
|
||||
flag.StringVar(&j.Params.Comment, "c", "", "Add a comment to the operation")
|
||||
flag.IntVar(&j.Params.TTL, "ttl", 172800, "Specify the TTL")
|
||||
flag.IntVar(&j.Params.TTL, "t", 172800, "Specify the TTL")
|
||||
flag.BoolVar(&j.Params.Append, "append", false, "Append the value, don't replace the whole record")
|
||||
flag.BoolVar(&j.Params.Append, "a", false, "Append the value, don't replace the whole record")
|
||||
flag.BoolVar(&j.Params.Debug, "debug", false, "Add debug info")
|
||||
flag.BoolVar(&j.Params.Debug, "d", false, "Add debug info")
|
||||
flag.Parse()
|
||||
|
||||
// copy and check the arguments
|
||||
if err := jsonStruct.SetArgs(&j, flag.Args()); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
Usage()
|
||||
}
|
||||
|
||||
// now we have the right number of arguments, and a valid command
|
||||
// Add Nonce
|
||||
if err := jsonStruct.SetNonce(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(5)
|
||||
}
|
||||
|
||||
// get the GPG Public Key from DM_GPG_KEY env var (usefull in case of several Yubikey)
|
||||
var gpgEnvKey = os.Getenv("DM_GPG_KEY")
|
||||
|
||||
// select Key from env if define
|
||||
var gpgKey string
|
||||
if gpgEnvKey != "" {
|
||||
gpgKey = gpgEnvKey
|
||||
// or the Yubikey
|
||||
} else {
|
||||
// get the Yubikey Public Key
|
||||
gpgYubikey, err := GetYubikey()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Warning, no Yubikey; forcing dry-run mode")
|
||||
jsonStruct.SetDryRun()
|
||||
}
|
||||
gpgKey = gpgYubikey
|
||||
}
|
||||
|
||||
// Transform into json
|
||||
jsonStr, err := json.MarshalIndent(jsonStruct, " ", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "panic")
|
||||
os.Exit(6)
|
||||
}
|
||||
// debug mode
|
||||
if j.Params.Debug {
|
||||
fmt.Println(string(jsonStr))
|
||||
}
|
||||
// Sign
|
||||
signed, err := Sign(jsonStr, gpgKey)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(7)
|
||||
}
|
||||
// Send to the server
|
||||
ret, err := sendQuery(signed)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(9)
|
||||
}
|
||||
for _, response := range ret {
|
||||
if response.Error.Message != "" {
|
||||
fmt.Fprintln(os.Stderr, "; "+
|
||||
strings.Replace(response.Error.Message, "\n", "\n; ", -1)+"\n")
|
||||
continue
|
||||
}
|
||||
for _, result := range response.Result {
|
||||
if strings.Trim(result.Comment, " \n") != "" {
|
||||
fmt.Fprintln(os.Stderr, "; "+
|
||||
strings.Replace(result.Comment, "\n", "\n; ", -1)+"\n")
|
||||
if result.Result != "" {
|
||||
fmt.Fprintln(os.Stderr, "; "+result.Result)
|
||||
}
|
||||
}
|
||||
if result.Changes != "" {
|
||||
fmt.Println(result.Changes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// formatDescr outputs the description while adding CR when the line is too long
|
||||
func (c *command) formatDescr() string {
|
||||
max := 70
|
||||
ret := ""
|
||||
current := 0
|
||||
for _, word := range strings.Split(c.descr, " ") {
|
||||
if strings.Contains(word, "\n") {
|
||||
current = 0
|
||||
ret += word + " "
|
||||
continue
|
||||
}
|
||||
if current+len(word) > max {
|
||||
ret += "\n "
|
||||
current = 0
|
||||
}
|
||||
current += len(word + " ")
|
||||
ret += word + " "
|
||||
}
|
||||
ret = strings.Trim(ret, " ")
|
||||
return ret
|
||||
}
|
||||
|
||||
// formatArgs concatenate the arguments for the Usage() function
|
||||
func (c *command) formatArgs() string {
|
||||
return strings.Join(c.args, " ")
|
||||
}
|
||||
|
||||
// iPtoReverse calculates the reverse name associated with an IPv4 or IPv6
|
||||
func iPtoReverse(ip net.IP) (arpa string) {
|
||||
const hexDigit = "0123456789abcdef"
|
||||
// code copied and adapted from the net library
|
||||
// ip can be 4 or 16 bytes long
|
||||
if ip.To4() != nil {
|
||||
if len(ip) == 16 {
|
||||
return uitoa(uint(ip[15])) + "." + uitoa(uint(ip[14])) + "." + uitoa(uint(ip[13])) + "." + uitoa(uint(ip[12])) + ".in-addr.arpa."
|
||||
}
|
||||
return uitoa(uint(ip[3])) + "." + uitoa(uint(ip[2])) + "." + uitoa(uint(ip[1])) + "." + uitoa(uint(ip[0])) + ".in-addr.arpa."
|
||||
}
|
||||
// Must be IPv6
|
||||
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
|
||||
|
||||
// Add it, in reverse, to the buffer
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
v := ip[i]
|
||||
buf = append(buf, hexDigit[v&0xF])
|
||||
buf = append(buf, '.')
|
||||
buf = append(buf, hexDigit[v>>4])
|
||||
buf = append(buf, '.')
|
||||
}
|
||||
// Append "ip6.arpa." and return (buf already has the final .)
|
||||
buf = append(buf, "ip6.arpa."...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// Convert unsigned integer to decimal string.
|
||||
// code copied from the net library
|
||||
func uitoa(val uint) string {
|
||||
if val == 0 { // avoid string allocation
|
||||
return "0"
|
||||
}
|
||||
var buf [20]byte // big enough for 64bit value base 10
|
||||
i := len(buf) - 1
|
||||
for val >= 10 {
|
||||
q := val / 10
|
||||
buf[i] = byte('0' + val - q*10)
|
||||
i--
|
||||
val = q
|
||||
}
|
||||
// val < 10
|
||||
buf[i] = byte('0' + val)
|
||||
return string(buf[i:])
|
||||
}
|
||||
|
||||
// ptrToIP converts a .arpa name to the corresponding IP
|
||||
func ptrToIP(s string) string {
|
||||
s = reverseParts(s) // reverse parts between dots (".")
|
||||
count := 0
|
||||
ip := ""
|
||||
version := 4
|
||||
for _, elt := range strings.Split(s, ".") {
|
||||
switch elt {
|
||||
case "":
|
||||
case "ip6":
|
||||
version = 6
|
||||
case "in-addr":
|
||||
case "arpa":
|
||||
default:
|
||||
count++
|
||||
ip += elt
|
||||
if version == 4 && count != 4 {
|
||||
ip += "."
|
||||
}
|
||||
if version == 6 && count%4 == 0 && count != 32 {
|
||||
ip += ":"
|
||||
}
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// reverse the part order on every member of the array
|
||||
func reverseParts(s string) string {
|
||||
parts := strings.Split(s, ".")
|
||||
for i, j := 0, len(parts)-1; i < j; i, j = i+1, j-1 {
|
||||
parts[i], parts[j] = parts[j], parts[i]
|
||||
}
|
||||
return strings.Join(parts, ".")
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/pyke369/golang-support/uconfig"
|
||||
)
|
||||
|
||||
func loadConfig(configFile string) (*HTTPServer, error) {
|
||||
// Parse the configuration file
|
||||
config, err := uconfig.New(configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h := NewHTTPServer(
|
||||
config.GetString("config.http.port", "127.0.0.01:8080"),
|
||||
config.GetString("config.http.key", ""),
|
||||
config.GetString("config.http.cert", ""),
|
||||
config.GetString("config.http.crl", ""),
|
||||
config.GetString("config.http.ca", ""),
|
||||
config.GetString("config.pdns.api-url", "http://127.0.0.1:8081/api/v1/servers/localhost"),
|
||||
config.GetString("config.pdns.api-key", ""),
|
||||
int(config.GetInteger("config.pdns.timeout", 3)),
|
||||
int(config.GetInteger("config.pdns.defaultTTL", 3600)),
|
||||
)
|
||||
populateHTTPServerZoneProfiles(config, h)
|
||||
populateHTTPServerAcls(config, h)
|
||||
populateHTTPServerProfiles(config, h)
|
||||
populateJSONRPCAcls(config, h)
|
||||
return h, err
|
||||
}
|
||||
|
||||
// populateHTTPServerAcls parses the native ACL, and put them into the wanted structures
|
||||
func populateHTTPServerAcls(config *uconfig.UConfig, h *HTTPServer) {
|
||||
for _, acl := range config.GetPaths("config.pdnsAcls") {
|
||||
h.AddPdnsACL(NewPdnsACL(
|
||||
config.GetString(acl+".regexp", ""),
|
||||
parseConfigArray(config, acl+".perms"),
|
||||
parseConfigArray(config, acl+".profiles"),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// populateJSONRPCAcls parses the json RPC API ACL, and put them into the wanted structures
|
||||
func populateJSONRPCAcls(config *uconfig.UConfig, h *HTTPServer) {
|
||||
for _, acl := range config.GetPaths("config.jrpcAcls") {
|
||||
perms := map[string][]string{}
|
||||
for _, action := range config.GetPaths(acl + ".perms") {
|
||||
perms[getSuffix(action)] = parseConfigArray(config, action)
|
||||
}
|
||||
h.AddjsonRPCACL(NewJSONRPCACL(
|
||||
perms,
|
||||
parseConfigArray(config, acl+".pgpProfiles"),
|
||||
parseConfigArray(config, acl+".sslProfiles"),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// populateHTTPServerProfiles parses profiles, and put them into the wanted structures
|
||||
func populateHTTPServerProfiles(config *uconfig.UConfig, h *HTTPServer) {
|
||||
// convert the configuration into authProfiles
|
||||
for _, profile := range config.GetPaths("config.profiles") {
|
||||
if p, err := parseConfigProfile(config, profile); err == nil {
|
||||
h.AddAuthProfile(getSuffix(profile), p)
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// populateHTTPServerZoneProfiles parses the ZoneProfiles and put them into the wanted structures
|
||||
func populateHTTPServerZoneProfiles(config *uconfig.UConfig, h *HTTPServer) {
|
||||
for _, profile := range config.GetPaths("config.zoneProfile") {
|
||||
if err := h.NewZoneProfile(
|
||||
getSuffix(profile),
|
||||
config.GetString(profile+".soa", ""),
|
||||
config.GetBoolean(profile+".default", false),
|
||||
config.GetBoolean(profile+".autoIncrement", true),
|
||||
parseConfigArray(config, profile+".nameservers"),
|
||||
parseConfigArray(config, profile+".whenRegexp"),
|
||||
); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
for _, entry := range config.GetPaths(profile + ".populate") {
|
||||
h.zoneProfiles[getSuffix(profile)].addDefaultEntry(
|
||||
config.GetString(entry+".name", ""),
|
||||
config.GetString(entry+".type", "txt"),
|
||||
config.GetString(entry+".value", ""),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseConfigProfile allows to use the adequate type of profile
|
||||
func parseConfigProfile(config *uconfig.UConfig, profile string) (AuthProfile, error) {
|
||||
t := config.GetString(profile+".type", "")
|
||||
switch t {
|
||||
case "regexp":
|
||||
return NewAuthProfileRegexp(
|
||||
config.GetString(profile+".subjectRegexp", ""),
|
||||
config.GetString(profile+".pgpKeys", "")), nil
|
||||
case "ldap":
|
||||
return NewAuthProfileLdap(config.GetString(profile+".subjectRegexp", ""),
|
||||
parseConfigArray(config, profile+".servers"),
|
||||
config.GetString(profile+".bindCn", ""),
|
||||
config.GetString(profile+".bindPw", ""),
|
||||
config.GetString(profile+".baseDN", ""),
|
||||
config.GetString(profile+".searchFilter", ""),
|
||||
config.GetString(profile+".attribute", ""),
|
||||
config.GetString(profile+".pgpAttribute", "pgpKey"),
|
||||
parseConfigArray(config, profile+".validValues"),
|
||||
config.GetBoolean(profile+".ssl", true),
|
||||
)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown profile type %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// parseConfigArray parses a uconf array
|
||||
func parseConfigArray(config *uconfig.UConfig, configpath string) []string {
|
||||
result := []string{}
|
||||
for _, i := range config.GetPaths(configpath) {
|
||||
if s := config.GetString(i, ""); s == "" {
|
||||
continue
|
||||
} else {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getSuffix gets the last part of a config path
|
||||
func getSuffix(s string) string {
|
||||
sl := strings.Split(s, ".")
|
||||
return sl[len(sl)-1]
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
pdns-auth-proxy (0.9.7-1+ubuntu2) unstable; urgency=medium
|
||||
|
||||
[ Xavier Henner ]
|
||||
* Initial euclide.org release
|
||||
|
||||
-- root <xavier@euclide.org> Fri, 17 Nov 2023 04:52:00 +0000
|
|
@ -0,0 +1 @@
|
|||
5
|
|
@ -0,0 +1,15 @@
|
|||
Source: pdns-auth-proxy
|
||||
Maintainer: Xavier Henner <xavier@euclide.org>
|
||||
Section: net
|
||||
Priority: optional
|
||||
Build-Depends:
|
||||
debhelper,
|
||||
golang-1.21
|
||||
Standards-Version: 4.4.1
|
||||
Vcs-Browser: https://git.euclide.org/euclide/pdns-auth-proxy
|
||||
Vcs-Git: https://git.euclide.org/euclide/pdns-auth-proxy.git
|
||||
|
||||
Package: pdns-auth-proxy
|
||||
Architecture: any
|
||||
Description: Proxy before pdns API
|
||||
Depends: ${misc:Depends}, ${shlibs:Depends}, pdns-server
|
|
@ -0,0 +1,2 @@
|
|||
[DEFAULT]
|
||||
pristine-tar = False
|
|
@ -0,0 +1,5 @@
|
|||
#pdns-auth-proxy base config
|
||||
|
||||
ENABLED=1
|
||||
CONFIGURATION="/etc/pdns-auth-proxy.config"
|
||||
OPTS="-syslog"
|
|
@ -0,0 +1,17 @@
|
|||
[Unit]
|
||||
Description=pdns-auth-proxy
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
User=pdns
|
||||
Group=pdns
|
||||
EnvironmentFile=/etc/default/pdns-auth-proxy
|
||||
StandardOutput=null
|
||||
StandardError=journal
|
||||
LimitNOFILE=65536
|
||||
ExecStart=/usr/sbin/pdns-auth-proxy -config ${CONFIGURATION} ${OPTS}
|
||||
Restart=always
|
||||
RestartSec=3
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/make -f
|
||||
|
||||
build: build-stamp
|
||||
|
||||
build-stamp:
|
||||
dh_testdir
|
||||
$(MAKE) build
|
||||
touch build-stamp
|
||||
|
||||
clean:
|
||||
dh_testdir
|
||||
dh_testroot
|
||||
rm -f build-stamp
|
||||
$(MAKE) clean
|
||||
dh_clean
|
||||
|
||||
install: build
|
||||
dh_testdir
|
||||
dh_prep
|
||||
dh_install pdns-auth-proxy /usr/sbin
|
||||
dh_install client/main.go /usr/share/doc/pdns-auth-proxy/client.go
|
||||
|
||||
binary-indep: build install
|
||||
|
||||
binary-arch: build install
|
||||
dh_testdir
|
||||
dh_testroot
|
||||
dh_installinit
|
||||
dh_fixperms
|
||||
dh_installdeb
|
||||
dh_gencontrol
|
||||
dh_md5sums
|
||||
dh_builddeb
|
||||
|
||||
binary: binary-indep binary-arch
|
||||
.PHONY: build clean binary-indep binary-arch binary install
|
|
@ -0,0 +1 @@
|
|||
3.0 (quilt)
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug pdns-auth-proxy",
|
||||
"mode": "debug",
|
||||
"cwd": "${workspaceRoot}",
|
||||
"args": ["-config", "pdns-proxy-test.conf"]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
version: '3'
|
||||
|
||||
networks:
|
||||
ns:
|
||||
driver: calico
|
||||
ipam:
|
||||
driver: calico-ipam
|
|
@ -0,0 +1,5 @@
|
|||
version: '3'
|
||||
|
||||
networks:
|
||||
ns:
|
||||
driver: bridge
|
|
@ -0,0 +1,21 @@
|
|||
version: '3'
|
||||
services:
|
||||
build: &build
|
||||
build: .
|
||||
command: make build
|
||||
volumes:
|
||||
- .:/src
|
||||
networks:
|
||||
- ns
|
||||
|
||||
fmt:
|
||||
<<: *build
|
||||
command: make fmt
|
||||
|
||||
lint:
|
||||
<<: *build
|
||||
command: make lint
|
||||
|
||||
test:
|
||||
<<: *build
|
||||
command: make test
|
|
@ -0,0 +1,72 @@
|
|||
version: '2'
|
||||
|
||||
services:
|
||||
|
||||
openldap:
|
||||
image: osixia/openldap:1.2.5
|
||||
container_name: openldap
|
||||
environment:
|
||||
LDAP_LOG_LEVEL: "256"
|
||||
# LDAP_ORGANISATION: "example.org."
|
||||
LDAP_DOMAIN: "example.org"
|
||||
LDAP_BASE_DN: "dc=example,dc=org"
|
||||
LDAP_ADMIN_PASSWORD: "admin"
|
||||
LDAP_READONLY_USER: "false"
|
||||
# LDAP_READONLY_USER_USERNAME: "ldap-ro"
|
||||
# LDAP_READONLY_USER_PASSWORD: "prout"
|
||||
LDAP_TLS: "true"
|
||||
LDAP_TLS_CRT_FILENAME: "ldap-cert.pem"
|
||||
LDAP_TLS_KEY_FILENAME: "ldap-key.pem"
|
||||
LDAP_TLS_CA_CRT_FILENAME: "ca.crt"
|
||||
LDAP_TLS_VERIFY_CLIENT: "never"
|
||||
tty: true
|
||||
stdin_open: true
|
||||
command: [ "--copy-service","--loglevel","debug" ]
|
||||
ports:
|
||||
- "389:389"
|
||||
- "636:636"
|
||||
volumes:
|
||||
- ./openldap/test.ldif:/container/service/slapd/assets/config/bootstrap/ldif/50-bootstrap.ldif
|
||||
- ./test:/container/service/slapd/certs
|
||||
hostname: ldap.example.org
|
||||
|
||||
pdns:
|
||||
image: synyx/pdns:latest
|
||||
container_name: pdns
|
||||
environment:
|
||||
PDNS_DEBUG_ENV: "true"
|
||||
PDNS_LAUNCH: "gmysql"
|
||||
PDNS_GMYSQL_HOST: "pdns-mysql"
|
||||
PDNS_GMYSQL_DBNAME: "pdns"
|
||||
PDNS_GMYSQL_USER: "pdns"
|
||||
PDNS_GMYSQL_PASSWORD: "pdns"
|
||||
PDNS_LOG_DNS_QUERIES: "yes"
|
||||
PDNS_LOGLEVEL: 5
|
||||
PDNS_API: "yes"
|
||||
PDNS_API_KEY: "123password"
|
||||
PDNS_API_LOFGILE: "/var/log/pdns-api.log"
|
||||
PDNS_WEBSERVER: "yes"
|
||||
PDNS_WEBSERVER_ADDRESS: "0.0.0.0"
|
||||
PDNS_WEBSERVER_ALLOW_FROM: "0.0.0.0/0;::/0"
|
||||
ports:
|
||||
- "53:53"
|
||||
- "53:53/udp"
|
||||
- "8081:8081"
|
||||
- "9120:9120"
|
||||
hostname: pdns
|
||||
links:
|
||||
- pdns-mysql
|
||||
|
||||
pdns-mysql:
|
||||
image: mariadb:10.3.10
|
||||
container_name: pdns-mysql
|
||||
environment:
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
MYSQL_DATABASE: "pdns"
|
||||
MYSQL_USER: "pdns"
|
||||
MYSQL_PASSWORD: "pdns"
|
||||
ports:
|
||||
- "3306:3306"
|
||||
volumes:
|
||||
- ./pdns/pdns.dump:/docker-entrypoint-initdb.d/pdns.sql
|
||||
hostname: pdns-mysql
|
|
@ -0,0 +1,58 @@
|
|||
# This is the default image startup configuration file
|
||||
# this file define environment variables used during the container **first start** in **startup files**.
|
||||
|
||||
# General container configuration
|
||||
# see table 5.1 in http://www.openldap.org/doc/admin24/slapdconf2.html for the available log levels.
|
||||
LDAP_LOG_LEVEL: 256
|
||||
|
||||
# Ulimit
|
||||
LDAP_NOFILE: 1024
|
||||
|
||||
# Do not perform any chown to fix file ownership
|
||||
DISABLE_CHOWN: false
|
||||
|
||||
# Required and used for new ldap server only
|
||||
LDAP_ORGANISATION: Example Inc.
|
||||
LDAP_DOMAIN: example.org
|
||||
LDAP_BASE_DN: #if empty automatically set from LDAP_DOMAIN
|
||||
|
||||
LDAP_ADMIN_PASSWORD: admin
|
||||
LDAP_CONFIG_PASSWORD: config
|
||||
|
||||
LDAP_READONLY_USER: false
|
||||
LDAP_READONLY_USER_USERNAME: readonly
|
||||
LDAP_READONLY_USER_PASSWORD: readonly
|
||||
|
||||
LDAP_RFC2307BIS_SCHEMA: false
|
||||
|
||||
# Backend
|
||||
LDAP_BACKEND: mdb
|
||||
|
||||
# Tls
|
||||
LDAP_TLS: true
|
||||
LDAP_TLS_CRT_FILENAME: ldap.crt
|
||||
LDAP_TLS_KEY_FILENAME: ldap.key
|
||||
LDAP_TLS_DH_PARAM_FILENAME: dhparam.pem
|
||||
LDAP_TLS_CA_CRT_FILENAME: ca.crt
|
||||
|
||||
LDAP_TLS_ENFORCE: false
|
||||
LDAP_TLS_CIPHER_SUITE: SECURE256:+SECURE128:-VERS-TLS-ALL:+VERS-TLS1.2:-RSA:-DHE-DSS:-CAMELLIA-128-CBC:-CAMELLIA-256-CBC
|
||||
LDAP_TLS_VERIFY_CLIENT: demand
|
||||
|
||||
# Replication
|
||||
LDAP_REPLICATION: false
|
||||
|
||||
|
||||
# Do not change the ldap config
|
||||
# - If set to true with an existing database, config will remain unchanged. Image tls and replication config will not be run.
|
||||
# The container can be started with LDAP_ADMIN_PASSWORD and LDAP_CONFIG_PASSWORD empty or filled with fake data.
|
||||
# - If set to true when bootstrapping a new database, bootstap ldif and schema will not be added and tls and replication config will not be run.
|
||||
KEEP_EXISTING_CONFIG: false
|
||||
|
||||
# Remove config after setup
|
||||
LDAP_REMOVE_CONFIG_AFTER_SETUP: true
|
||||
|
||||
# ssl-helper environment variables prefix
|
||||
LDAP_SSL_HELPER_PREFIX: ldap # ssl-helper first search config from LDAP_SSL_HELPER_* variables, before SSL_HELPER_* variables.
|
||||
|
||||
SSL_HELPER_AUTO_RENEW_SERVICES_IMPACTED: slapd
|
|
@ -0,0 +1,109 @@
|
|||
# LDIF Export for ou=users,dc=example,dc=org
|
||||
|
||||
version: 1
|
||||
|
||||
dn: dc=example,dc=org
|
||||
o: example
|
||||
objectClass: dcObject
|
||||
objectClass: organization
|
||||
description: Example Directory
|
||||
|
||||
# Entry 1: ou=users,dc=example,dc=org
|
||||
dn: ou=users,dc=example,dc=org
|
||||
changetype: add
|
||||
description: Utilisateurs
|
||||
objectclass: organizationalUnit
|
||||
ou: users
|
||||
|
||||
# Entry 2: uid=jdoe,ou=users,dc=example,dc=org
|
||||
dn: uid=jdoe,ou=users,dc=example,dc=org
|
||||
changetype: add
|
||||
cn: Jane Doe
|
||||
description: jdoe
|
||||
description: infra
|
||||
gidnumber: 0
|
||||
homedirectory: /root
|
||||
sshPublicKey: ""
|
||||
loginshell: /bin/bash
|
||||
mail: karin.aitsiamer@example.org
|
||||
objectclass: inetOrgPerson
|
||||
objectclass: posixAccount
|
||||
objectclass: ldapPublicKey
|
||||
objectclass: person
|
||||
objectclass: top
|
||||
sn: Jane Doe
|
||||
uid: jdoe
|
||||
uidnumber: 0
|
||||
|
||||
# Entry 3: uid=jdo,ou=users,dc=example,dc=org
|
||||
dn: uid=jdo,ou=users,dc=example,dc=org
|
||||
changetype: add
|
||||
cn: John DO
|
||||
sshPublicKey: ""
|
||||
description: jdo
|
||||
description: dev
|
||||
description: api-01.dev
|
||||
description: api-02.dev
|
||||
description: bbxlogs
|
||||
description: graphdb
|
||||
description: inspect-02
|
||||
description: ejobs
|
||||
description: eseblook
|
||||
description: esebworker
|
||||
description: esebed
|
||||
description: esebed.dev
|
||||
description: ebworker.dev
|
||||
description: eseblook.dev
|
||||
description: searchprov
|
||||
description: syslog-01
|
||||
description: cronjobs
|
||||
description: staging
|
||||
description: pebprov
|
||||
description: bodyguard
|
||||
description: webed
|
||||
description: pebed
|
||||
description: mebed
|
||||
description: ebworker
|
||||
description: prov
|
||||
description: web-01.dev
|
||||
description: web-02.dev
|
||||
description: releaseslave
|
||||
description: release
|
||||
description: inspectslave
|
||||
description: statyle
|
||||
description: inspect
|
||||
description: orscale-03
|
||||
description: dock-001
|
||||
description: npmrepo
|
||||
gidnumber: 0
|
||||
homedirectory: /home/jdo
|
||||
loginshell: /bin/bash
|
||||
mail: klemen.sever@example.org
|
||||
objectclass: inetOrgPerson
|
||||
objectclass: posixAccount
|
||||
objectclass: ldapPublicKey
|
||||
objectclass: person
|
||||
objectclass: top
|
||||
sn: John Do
|
||||
uid: jdo
|
||||
uidnumber: 0
|
||||
|
||||
# Entry 4: uid=xavier,ou=users,dc=example,dc=org
|
||||
dn: uid=xavier,ou=users,dc=example,dc=org
|
||||
changetype: add
|
||||
cn: Xavier Henner
|
||||
sshPublicKey: ""
|
||||
description: infra
|
||||
employeenumber: 1005
|
||||
gidnumber: 0
|
||||
homedirectory: /root
|
||||
loginshell: /bin/bash
|
||||
mail: xavier.henner@example.org
|
||||
objectclass: inetOrgPerson
|
||||
objectclass: posixAccount
|
||||
objectclass: ldapPublicKey
|
||||
objectclass: person
|
||||
objectclass: top
|
||||
sn: Xavier Henner
|
||||
uid: xavier
|
||||
uidnumber: 0
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"comments": [
|
||||
{
|
||||
"account": "kr1",
|
||||
"content": "My first API-created zone",
|
||||
"name": "uhuh",
|
||||
"type": "dunno"
|
||||
}
|
||||
],
|
||||
"kind": "Native",
|
||||
"masters": [],
|
||||
"name": "example2.net.",
|
||||
"nameservers": [
|
||||
"ns1.example.net.",
|
||||
"ns2.example.net."
|
||||
],
|
||||
"records": [
|
||||
{
|
||||
"content": "ns.example.net. hostmaster.example.com. 1 1800 900 604800 86400",
|
||||
"disabled": false,
|
||||
"name": "example.net",
|
||||
"ttl": 86400,
|
||||
"type": "SOA"
|
||||
},
|
||||
{
|
||||
"content": "192.168.1.42",
|
||||
"disabled": false,
|
||||
"name": "www.example.net",
|
||||
"ttl": 3600,
|
||||
"type": "A"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
log-dns-queries=yes
|
||||
loglevel=5
|
||||
allow-recursion=127.0.0.1
|
||||
default-ttl=60
|
||||
disable-axfr=no
|
||||
api=yes
|
||||
api-key=123password
|
||||
api-logfile=/var/log/pdns-api.log
|
||||
webserver=yes
|
||||
webserver-address=0.0.0.0
|
||||
webserver-allow-from=0.0.0.0/0;::/0
|
||||
guardian=yes
|
||||
launch=
|
||||
local-address=0.0.0.0
|
||||
local-ipv6=::
|
||||
master=yes
|
||||
soa-expire-default=1209600
|
||||
soa-minimum-ttl=60
|
||||
soa-refresh-default=60
|
||||
soa-retry-default=60
|
|
@ -0,0 +1,105 @@
|
|||
DROP TABLE IF EXISTS records;
|
||||
CREATE TABLE records (
|
||||
id BIGINT AUTO_INCREMENT,
|
||||
domain_id INT DEFAULT NULL,
|
||||
name VARCHAR(255) DEFAULT NULL,
|
||||
type VARCHAR(10) DEFAULT NULL,
|
||||
content VARCHAR(64000) DEFAULT NULL,
|
||||
ttl INT DEFAULT NULL,
|
||||
prio INT DEFAULT NULL,
|
||||
change_date INT DEFAULT NULL,
|
||||
disabled TINYINT(1) DEFAULT 0,
|
||||
ordername VARCHAR(255) BINARY DEFAULT NULL,
|
||||
auth TINYINT(1) DEFAULT 1,
|
||||
PRIMARY KEY (id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
CREATE INDEX nametype_index ON records(name,type);
|
||||
CREATE INDEX domain_id ON records(domain_id);
|
||||
CREATE INDEX ordername ON records (ordername);
|
||||
|
||||
|
||||
DROP TABLE IF EXISTS supermasters;
|
||||
CREATE TABLE supermasters (
|
||||
ip VARCHAR(64) NOT NULL,
|
||||
nameserver VARCHAR(255) NOT NULL,
|
||||
account VARCHAR(40) CHARACTER SET 'utf8' NOT NULL,
|
||||
PRIMARY KEY (ip, nameserver)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
|
||||
DROP TABLE IF EXISTS comments;
|
||||
CREATE TABLE comments (
|
||||
id INT AUTO_INCREMENT,
|
||||
domain_id INT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
type VARCHAR(10) NOT NULL,
|
||||
modified_at INT NOT NULL,
|
||||
account VARCHAR(40) CHARACTER SET 'utf8' DEFAULT NULL,
|
||||
comment TEXT CHARACTER SET 'utf8' NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
CREATE INDEX comments_name_type_idx ON comments (name, type);
|
||||
CREATE INDEX comments_order_idx ON comments (domain_id, modified_at);
|
||||
|
||||
|
||||
DROP TABLE IF EXISTS domainmetadata;
|
||||
CREATE TABLE domainmetadata (
|
||||
id INT AUTO_INCREMENT,
|
||||
domain_id INT NOT NULL,
|
||||
kind VARCHAR(32),
|
||||
content TEXT,
|
||||
PRIMARY KEY (id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
CREATE INDEX domainmetadata_idx ON domainmetadata (domain_id, kind);
|
||||
|
||||
|
||||
DROP TABLE IF EXISTS cryptokeys;
|
||||
CREATE TABLE cryptokeys (
|
||||
id INT AUTO_INCREMENT,
|
||||
domain_id INT NOT NULL,
|
||||
flags INT NOT NULL,
|
||||
active BOOL,
|
||||
content TEXT,
|
||||
PRIMARY KEY(id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
CREATE INDEX domainidindex ON cryptokeys(domain_id);
|
||||
|
||||
|
||||
DROP TABLE IF EXISTS tsigkeys;
|
||||
CREATE TABLE tsigkeys (
|
||||
id INT AUTO_INCREMENT,
|
||||
name VARCHAR(255),
|
||||
algorithm VARCHAR(50),
|
||||
secret VARCHAR(255),
|
||||
PRIMARY KEY (id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
DROP TABLE IF EXISTS domains;
|
||||
CREATE TABLE domains (
|
||||
id INT AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
master VARCHAR(128) DEFAULT NULL,
|
||||
last_check INT DEFAULT NULL,
|
||||
type VARCHAR(6) NOT NULL,
|
||||
notified_serial INT DEFAULT NULL,
|
||||
account VARCHAR(40) CHARACTER SET 'utf8' DEFAULT NULL,
|
||||
PRIMARY KEY (id)
|
||||
) Engine=InnoDB CHARACTER SET 'latin1';
|
||||
|
||||
CREATE UNIQUE INDEX name_index ON domains(name);
|
||||
|
||||
|
||||
|
||||
|
||||
CREATE UNIQUE INDEX namealgoindex ON tsigkeys(name, algorithm);
|
||||
|
||||
ALTER TABLE records ADD CONSTRAINT `records_domain_id_ibfk` FOREIGN KEY (`domain_id`) REFERENCES `domains` (`id`) ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE comments ADD CONSTRAINT `comments_domain_id_ibfk` FOREIGN KEY (`domain_id`) REFERENCES `domains` (`id`) ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE domainmetadata ADD CONSTRAINT `domainmetadata_domain_id_ibfk` FOREIGN KEY (`domain_id`) REFERENCES `domains` (`id`) ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE cryptokeys ADD CONSTRAINT `cryptokeys_domain_id_ibfk` FOREIGN KEY (`domain_id`) REFERENCES `domains` (`id`) ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
[
|
||||
{ "id": 0, "method": "list", "params": { "ignore-error": true }},
|
||||
{ "id": 1, "method": "newzone", "params": { "name": "example.com" }},
|
||||
{ "id": 2, "method": "newzone", "params": { "name": "2.0.192.in-addr.arpa" }},
|
||||
{ "id": 3, "method": "a", "params": { "name": "mail.example.com", "value": "192.0.2.12" }},
|
||||
{ "id": 4, "method": "a", "params": { "name": "localhost.example.com", "value": "127.0.0.1" }},
|
||||
{ "id": 5, "method": "a", "params": { "name": "www.example.com", "value": "192.0.2.10" }},
|
||||
{ "id": 6, "method": "a", "params": { "name": "a.example.com", "value": "192.0.2.53" }},
|
||||
{ "id": 7, "method": "a", "params": { "name": "b.example.com", "value": "192.0.2.54" }},
|
||||
{ "id": 8, "method": "list" },
|
||||
{ "id": 9, "method": "a", "params": { "name": "toto.example.com", "value": "1.1.1.1" }},
|
||||
{ "id": 10, "method": "list" },
|
||||
{ "id": 11, "method": "dump", "params": { "name": "example.com" }},
|
||||
{ "id": 12, "method": "a", "params": { "name": "toto.example.com", "value": "1.1.1.1" }},
|
||||
{ "id": 13, "method": "a", "params": { "append": true, "name": "toto.example.com", "value": "192.0.2.30" }},
|
||||
{ "id": 14, "method": "dump", "params": { "name": "2.0.192.in-addr.arpa" }},
|
||||
{ "id": 15, "method": "a", "params": { "name": "toto.example.com", "value": "1.1.1.1" }},
|
||||
{ "id": 16, "method": "dump", "params": { "name": "2.0.192.in-addr.arpa" }},
|
||||
{ "id": 17, "method": "dump", "params": { "name": "example.com" }},
|
||||
{ "id": 18, "method": "caa", "params": { "name": "example.com", "value": "0 issue Digicert.com" }},
|
||||
{ "id": 19, "method": "txt", "params": { "name": "toto.example.com", "value": "text and spaces" }},
|
||||
{ "id": 20, "method": "a", "params": { "name": "toto2.example.com", "value": "1.1.1.1" }},
|
||||
{ "id": 21, "method": "ptr", "params": { "ignore-error": true, "name": "10.2.0.192.in-addr.arpa", "value": "www.example.com" }},
|
||||
{ "id": 22, "method": "ptr", "params": { "ignore-error": true, "name": "11.2.0.192.in-addr.arpa", "value": "mail.example.com" }},
|
||||
{ "id": 23, "method": "ptr", "params": { "name": "12.2.0.192.in-addr.arpa", "value": "mail.example.com" }},
|
||||
{ "id": 24, "method": "ttl", "params": { "comment": "test", "name": "toto2.example.com", "ttl": 300 }},
|
||||
{ "id": 25, "method": "ttl", "params": { "ignore-error": true, "name": "toto2.example.com", "ttl": 300 }},
|
||||
{ "id": 26, "method": "ttl", "params": { "ignore-error": true, "name": "toto3.example.com", "ttl": 300 }},
|
||||
{ "id": 27, "method": "dump", "params": { "name": "example.com" }},
|
||||
{ "id": 28, "method": "cname", "params": { "name": "titi.example.com", "value": "toto.example.com" }},
|
||||
{ "id": 29, "method": "mx", "params": { "ignore-error": true, "name": "example.com", "value": "titi.example.com" }},
|
||||
{ "id": 30, "method": "mx", "params": { "ignore-error": true, "name": "example.com", "value": "20 titi.example.com" }},
|
||||
{ "id": 31, "method": "mx", "params": { "append": true, "name": "example.com", "value": "20 www.example.com" }},
|
||||
{ "id": 32, "method": "delete", "params": { "ignore-error": true, "name": "toto.example.com" }},
|
||||
{ "id": 33, "method": "delete", "params": { "name": "toto.example.com", "value": "1.1.1.1" }},
|
||||
{ "id": 34, "method": "delete", "params": { "name": "toto.example.com", "value": "text and spaces" }},
|
||||
{ "id": 35, "method": "delete", "params": { "name": "titi.example.com" }},
|
||||
{ "id": 36, "method": "delete", "params": { "name": "toto2.example.com" }},
|
||||
{ "id": 37, "method": "a", "params": { "name": "device.toto.example.com", "value":"1.1.1.1" }},
|
||||
{ "id": 38, "method": "newzone", "params": { "name": "toto.example.com" }},
|
||||
{ "id": 39, "method": "cname", "params": { "name": "bidule.titi.example.com", "value": "www.example.com" }},
|
||||
{ "id": 40, "method": "ns", "params": { "ignore-error": true, "name": "titi.example.com", "value": "truc.example.com" }},
|
||||
{ "id": 41, "method": "ns", "params": { "ignore-error": true, "name": "titi.example.com", "value": "bidule.example.com" }},
|
||||
{ "id": 42, "method": "ns", "params": { "ignore-error": true, "name": "titi.example.com", "value": "truc.titi.example.com" }},
|
||||
{ "id": 43, "method": "ns", "params": { "ignore-error": true, "name": "titi.example.com", "value": "truc.example.com" }},
|
||||
{ "id": 44, "method": "aaaa", "params": { "name": "truc.titi.example.com", "value": "2001:7a8::1" }},
|
||||
{ "id": 45, "method": "delete", "params": { "name": "bidule.titi.example.com", "value": "www.exa2mple.com" }},
|
||||
{ "id": 46, "method": "delete", "params": { "name": "bidule.titi.example.com", "value": "www.example.com" }},
|
||||
{ "id": 47, "method": "ns", "params": { "name": "titi.example.com", "value": "truc.titi.example.com" }},
|
||||
{ "id": 48, "method": "ns", "params": { "append": true, "name": "titi.example.com", "value": "a.exemple.com" }},
|
||||
{ "id": 49, "method": "search", "params": { "name": "*" }},
|
||||
{ "id": 50, "method": "domain", "params": { "name": "www.example.com" }},
|
||||
{ "id": 51, "method": "cname", "params": { "name": "www5.example.com", "value": "indirect.titi.example.com" }},
|
||||
{ "id": 52, "method": "domain", "params": { "ignore-error": true, "name": "www.exr*ple.com" }},
|
||||
{ "id": 53, "method": "dump", "params": { "ignore-error": true, "name": "badexample.com" }}
|
||||
]
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,78 @@
|
|||
#!/bin/sh
|
||||
|
||||
all: client-cert.pem server-cert.pem badclient-cert.pem public-key.txt root.crl.pem ldap-cert.pem webclient-cert.pem
|
||||
|
||||
clean:
|
||||
rm -f *.key *.pem *.csr *.crt ca.srl public-key.txt index.txt index.txt.attr index.txt.old crlnumber crlnumber.old ../replay/commands.asc
|
||||
|
||||
myCA.key:
|
||||
openssl genrsa -out myCA.key 2048
|
||||
|
||||
public-key.txt:
|
||||
rm -f ~/.gnupg/testkeys.gpg
|
||||
gpg --no-default-keyring --keyring testkeys.gpg --batch --generate-key gpg-key-conf
|
||||
gpg --no-default-keyring --keyring testkeys.gpg --armor --export joe@foo.bar | sed -e ':a' -e 'N' -e '$$!ba' -e 's/\n/\\n/g' | tr -d "\n" > public-key.txt
|
||||
rm -f ../replay/commands.asc
|
||||
test -f ../replay/commands && gpg --no-default-keyring --keyring testkeys.gpg --clear-sign -u joe@foo.bar ../replay/commands || echo
|
||||
rm -f ~/.gnupg/testkeys.gpg
|
||||
|
||||
ca.crt: myCA.key
|
||||
OPENSSL_CONF=ca.cnf openssl req -x509 -new -nodes -key myCA.key -sha256 -days 1825 -out ca.crt
|
||||
|
||||
server-key.pem:
|
||||
openssl genrsa -out server-key.pem 2048
|
||||
|
||||
server-csr.pem: server-key.pem
|
||||
OPENSSL_CONF=server.cnf openssl req -new -key server-key.pem -out server-csr.pem
|
||||
|
||||
client-key.pem:
|
||||
openssl genrsa -out client-key.pem 2048
|
||||
|
||||
badclient-csr.pem: client-key.pem
|
||||
OPENSSL_CONF=badclient.cnf openssl req -new -key client-key.pem -out badclient-csr.pem
|
||||
|
||||
client-csr.pem: client-key.pem
|
||||
OPENSSL_CONF=client.cnf openssl req -new -key client-key.pem -out client-csr.pem
|
||||
|
||||
server-cert.pem: server-csr.pem ca.crt
|
||||
openssl x509 -req -in server-csr.pem -CA ca.crt -CAkey myCA.key \
|
||||
-CAcreateserial -out server-cert.pem -days 1825 -sha256
|
||||
|
||||
client-cert.pem: client-csr.pem ca.crt
|
||||
openssl x509 -req -in client-csr.pem -CA ca.crt -CAkey myCA.key \
|
||||
-CAcreateserial -out client-cert.pem -days 1825 -sha256
|
||||
|
||||
badclient-cert.pem: badclient-csr.pem ca.crt
|
||||
openssl x509 -req -in badclient-csr.pem -CA ca.crt -CAkey myCA.key \
|
||||
-CAcreateserial -out badclient-cert.pem -days 1825 -sha256
|
||||
|
||||
webclient-key.pem:
|
||||
openssl genrsa -out webclient-key.pem 2048
|
||||
|
||||
webclient-csr.pem: webclient-key.pem
|
||||
OPENSSL_CONF=webclient.cnf openssl req -new -key webclient-key.pem -out webclient-csr.pem
|
||||
|
||||
webclient-cert.pem: webclient-csr.pem
|
||||
openssl x509 -req -in webclient-csr.pem -CA ca.crt -CAkey myCA.key \
|
||||
-CAcreateserial -out webclient-cert.pem -days 1825 -sha256
|
||||
|
||||
ldap-key.pem:
|
||||
openssl genrsa -out ldap-key.pem 2048
|
||||
|
||||
ldap-csr.pem: ldap-key.pem
|
||||
OPENSSL_CONF=ldap.cnf openssl req -new -key ldap-key.pem -out ldap-csr.pem
|
||||
|
||||
ldap-cert.pem: ldap-csr.pem
|
||||
openssl x509 -req -in ldap-csr.pem -CA ca.crt -CAkey myCA.key \
|
||||
-CAcreateserial -out ldap-cert.pem -days 1825 -sha256
|
||||
|
||||
index.txt: badclient-cert.pem
|
||||
touch index.txt
|
||||
echo 01 > crlnumber
|
||||
openssl ca -cert ca.crt -keyfile myCA.key -config ca.cnf -revoke badclient-cert.pem
|
||||
|
||||
root.crl.pem: index.txt
|
||||
openssl ca -config ca.cnf -gencrl -keyfile myCA.key -cert ca.crt -out root.crl.pem
|
||||
cat ca.crt >> root.crl.pem
|
||||
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
string_mask = utf8only
|
||||
|
||||
#req_extensions = v3_req
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Dailymoton Fake PKI
|
||||
CN=invalidserver
|
||||
|
||||
#[ v3_req ]
|
||||
|
||||
#basicConstraints = CA:FALSE
|
||||
#keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
|
@ -0,0 +1,40 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
[ca]
|
||||
default_ca = CA_default
|
||||
|
||||
[ CA_default ]
|
||||
# Directory and file locations.
|
||||
dir = "."
|
||||
certs = $dir
|
||||
crl_dir = $dir
|
||||
new_certs_dir = $dir
|
||||
database = $dir/index.txt
|
||||
serial = $dir/serial
|
||||
RANDFILE = $dir/.rand
|
||||
string_mask = utf8only
|
||||
default_md = sha256
|
||||
|
||||
# For certificate revocation lists.
|
||||
crlnumber = $dir/crlnumber
|
||||
crl = $dir/ca.crl.pem
|
||||
default_crl_days = 30
|
||||
|
||||
#req_extensions = v3_req
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Dailymotion Fake PKI
|
||||
CN=Dailymotion Fake CA
|
||||
|
||||
#[ v3_req ]
|
||||
|
||||
#basicConstraints = CA:FALSE
|
||||
#keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
|
@ -0,0 +1,22 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
string_mask = utf8only
|
||||
|
||||
#req_extensions = v3_req
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Dailymoton Fake PKI
|
||||
CN=validserver
|
||||
|
||||
#[ v3_req ]
|
||||
|
||||
#basicConstraints = CA:FALSE
|
||||
#keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
|
@ -0,0 +1,13 @@
|
|||
%echo Generating a basic OpenPGP key
|
||||
%no-protection
|
||||
Key-Type: DSA
|
||||
Key-Length: 1024
|
||||
Subkey-Type: ELG-E
|
||||
Subkey-Length: 1024
|
||||
Name-Real: Joe Tester
|
||||
Name-Comment: with stupid passphrase
|
||||
Name-Email: joe@foo.bar
|
||||
Expire-Date: 0
|
||||
# Do a commit here, so that we can later print "done" :-)
|
||||
%commit
|
||||
%echo done
|
|
@ -0,0 +1,28 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
string_mask = utf8only
|
||||
|
||||
#req_extensions = v3_req
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Example Fake PKI
|
||||
CN=ldap.example.org
|
||||
|
||||
[ req_ext ]
|
||||
subjectAltName = @alt_names
|
||||
|
||||
[ alt_names ]
|
||||
DNS.1 = localhost
|
||||
DNS.2 = ldap
|
||||
#[ v3_req ]
|
||||
|
||||
#basicConstraints = CA:FALSE
|
||||
#keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
|
@ -0,0 +1,28 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
string_mask = utf8only
|
||||
|
||||
req_extensions = v3_req
|
||||
subjectAltName = @alt_names
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Dailymoton Fake PKI
|
||||
CN=localhost
|
||||
|
||||
[ v3_req ]
|
||||
basicConstraints = CA:FALSE
|
||||
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
||||
|
||||
[ alt_names ]
|
||||
DNS.1 = ldap
|
||||
DNS.2 = localhost
|
||||
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# http://spin.atomicobject.com/2014/05/12/openssl-commands/
|
||||
[ req ]
|
||||
prompt = no
|
||||
default_bits = 2048
|
||||
encrypt_key = no
|
||||
distinguished_name = req_distinguished_name
|
||||
|
||||
string_mask = utf8only
|
||||
|
||||
#req_extensions = v3_req
|
||||
|
||||
[ req_distinguished_name ]
|
||||
C=FR
|
||||
ST=Ile de France
|
||||
L=Paris
|
||||
O=Example Fake PKI
|
||||
CN=xavier@example.org
|
||||
OU=Engineering
|
||||
emailAddress=security@example.org
|
||||
|
||||
#[ v3_req ]
|
||||
|
||||
#basicConstraints = CA:FALSE
|
||||
#keyUsage = nonRepudiation, digitalSignature, keyEncipherment
|
||||
#Subject: CN = xaver@example.org, C = FR, ST = Ile de France, L = Paris, O = Example, OU = Engineering, emailAddress = security@example.org
|
|
@ -0,0 +1,17 @@
|
|||
module git.euclide.org/euclide/pdns-auth-proxy
|
||||
|
||||
go 1.21.1
|
||||
|
||||
require (
|
||||
github.com/go-ldap/ldap/v3 v3.4.6
|
||||
github.com/jarcoal/httpmock v1.3.1
|
||||
github.com/pyke369/golang-support v0.0.0-20231112163947-ff7c18596096
|
||||
golang.org/x/crypto v0.15.0
|
||||
gopkg.in/djherbis/times.v1 v1.3.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect
|
||||
github.com/google/uuid v1.3.1 // indirect
|
||||
)
|
|
@ -0,0 +1,72 @@
|
|||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
|
||||
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA=
|
||||
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.5 h1:MNHlNMBDgEKD4TcKr36vQN68BA00aDfjIt3/bD50WnA=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.5/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-ldap/ldap/v3 v3.4.6 h1:ert95MdbiG7aWo/oPYp9btL3KJlMPKnP58r09rI8T+A=
|
||||
github.com/go-ldap/ldap/v3 v3.4.6/go.mod h1:IGMQANNtxpsOzj7uUAMjpGBaOVTC4DYyIy8VsTdxmtc=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww=
|
||||
github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
|
||||
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
|
||||
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pyke369/golang-support v0.0.0-20231112163947-ff7c18596096 h1:Nyi3FoyXqAre6ciqcG5to+hbu145HdYEQljA0lznr/U=
|
||||
github.com/pyke369/golang-support v0.0.0-20231112163947-ff7c18596096/go.mod h1:851u6g/3itVw+DBjbypRl7zzpXig+I4g2ndBiwsUUm8=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
|
||||
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/djherbis/times.v1 v1.3.0 h1:uxMS4iMtH6Pwsxog094W0FYldiNnfY/xba00vq6C2+o=
|
||||
gopkg.in/djherbis/times.v1 v1.3.0/go.mod h1:AQlg6unIsrsCEdQYhTzERy542dz6SFdQFZFv6mUY0P8=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -0,0 +1,567 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"embed"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/openpgp/clearsign"
|
||||
"gopkg.in/djherbis/times.v1"
|
||||
)
|
||||
|
||||
type (
|
||||
// HTTPServer is the webservice main object with its configuration parameters
|
||||
HTTPServer struct {
|
||||
Port string
|
||||
key string
|
||||
cert string
|
||||
crl string
|
||||
decodedCrl *pkix.CertificateList
|
||||
decodedCA []*x509.Certificate
|
||||
decodedCrlTime time.Time
|
||||
authProfiles map[string]AuthProfile
|
||||
pdnsAcls []*PdnsACL
|
||||
jrpcAPIACls []*JSONRPCACL
|
||||
certPool *x509.CertPool
|
||||
debug bool
|
||||
m sync.RWMutex
|
||||
dns *PowerDNS
|
||||
nonceGen string
|
||||
certCache map[string]time.Time
|
||||
zoneProfiles map[string]*zoneProfile
|
||||
}
|
||||
zoneProfile struct {
|
||||
Default bool
|
||||
NameServers []string
|
||||
SOA string
|
||||
DefaultEntries []*JSONInput
|
||||
Regexp []*regexp.Regexp
|
||||
AutoInc bool
|
||||
}
|
||||
)
|
||||
|
||||
// initalize local files
|
||||
//
|
||||
//go:embed web/*
|
||||
var embeddedFS embed.FS
|
||||
|
||||
// JSONRPCNewError return a valid json rpc 2.0 error
|
||||
func JSONRPCNewError(code, id int, message string) JSONRPCError {
|
||||
e := JSONRPCError{ID: id, JSONRPC: "2.0"}
|
||||
e.Error.Code = code
|
||||
e.Error.Message = message
|
||||
return e
|
||||
}
|
||||
|
||||
func (e JSONRPCError) String() string {
|
||||
return string(printJSON(e))
|
||||
}
|
||||
|
||||
// NewHTTPServer initializes HTTPServer
|
||||
func NewHTTPServer(port, key, cert, crl, ca, pdnsServer, pdnsKey string, timeout, ttl int) *HTTPServer {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
h := HTTPServer{
|
||||
Port: port,
|
||||
key: key,
|
||||
cert: cert,
|
||||
crl: crl,
|
||||
nonceGen: NewSalt(25),
|
||||
zoneProfiles: map[string]*zoneProfile{},
|
||||
certCache: map[string]time.Time{},
|
||||
authProfiles: map[string]AuthProfile{},
|
||||
pdnsAcls: []*PdnsACL{},
|
||||
certPool: x509.NewCertPool(),
|
||||
decodedCA: []*x509.Certificate{},
|
||||
}
|
||||
|
||||
rawCA, err := ioutil.ReadFile(ca)
|
||||
if err != nil {
|
||||
log.Fatal("ca:", err)
|
||||
}
|
||||
for len(rawCA) > 0 {
|
||||
var block *pem.Block
|
||||
block, rawCA = pem.Decode(rawCA)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
|
||||
continue
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
h.certPool.AddCert(cert)
|
||||
h.decodedCA = append(h.decodedCA, cert)
|
||||
}
|
||||
if h.dns, err = NewClient(pdnsServer, pdnsKey, timeout, ttl); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if _, err := h.RefreshCRL(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return &h
|
||||
}
|
||||
|
||||
// NewZoneProfile add zone profiles to the structure
|
||||
func (h *HTTPServer) NewZoneProfile(zoneType, soa string, isDefault, autoInc bool, nameServers, rules []string) error {
|
||||
for t, profile := range h.zoneProfiles {
|
||||
if t == zoneType {
|
||||
return fmt.Errorf("zone type %s already defined", t)
|
||||
}
|
||||
if isDefault && profile.Default {
|
||||
return fmt.Errorf("zone type %s is already the default one", t)
|
||||
}
|
||||
}
|
||||
if len(nameServers) == 0 {
|
||||
return errors.New("no nameservers in the configuration")
|
||||
}
|
||||
for i := range nameServers {
|
||||
nameServers[i] = addPoint(nameServers[i])
|
||||
}
|
||||
z := &zoneProfile{
|
||||
Default: isDefault,
|
||||
NameServers: nameServers,
|
||||
Regexp: []*regexp.Regexp{},
|
||||
SOA: soa,
|
||||
AutoInc: autoInc,
|
||||
}
|
||||
for _, r := range rules {
|
||||
re, err := regexp.Compile(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
z.Regexp = append(z.Regexp, re)
|
||||
}
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.zoneProfiles[zoneType] = z
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *zoneProfile) addDefaultEntry(name, action, value string) {
|
||||
params := JSONInputParams{
|
||||
Name: name,
|
||||
Value: value,
|
||||
TTL: 172800,
|
||||
}
|
||||
input := &JSONInput{Method: action, ignoreBadDomain: true, Params: params}
|
||||
z.DefaultEntries = append(z.DefaultEntries, input)
|
||||
}
|
||||
|
||||
// Lock the structure
|
||||
func (h *HTTPServer) lock() {
|
||||
h.m.Lock()
|
||||
}
|
||||
|
||||
// ...or unlock it
|
||||
func (h *HTTPServer) unlock() {
|
||||
h.m.Unlock()
|
||||
}
|
||||
|
||||
// Debug facilities
|
||||
func (h *HTTPServer) Debug() {
|
||||
h.debug = true
|
||||
h.dns.Debug()
|
||||
}
|
||||
|
||||
// verifyNonce check that the once is valid and less than 10s old
|
||||
func (h *HTTPServer) verifyNonce(nonce string) bool {
|
||||
// cannot appear in production, but useful for tests
|
||||
if len(h.nonceGen) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(nonce) < 47 {
|
||||
return false
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
salt := nonce[:4]
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
if salt+ComputeHmac256(fmt.Sprintf("%s%d", salt, now-int64(i)), h.nonceGen) == nonce {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sendNonce return a valid nonce to the user
|
||||
func (h *HTTPServer) sendNonce(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now().Unix()
|
||||
salt := NewSalt(4)
|
||||
fmt.Fprintf(w, salt+ComputeHmac256(fmt.Sprintf("%s%d", salt, now), h.nonceGen))
|
||||
}
|
||||
|
||||
// AddAuthProfile adds a profile in the server config
|
||||
func (h *HTTPServer) AddAuthProfile(name string, p AuthProfile) {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.authProfiles[name] = p
|
||||
}
|
||||
|
||||
// AddPdnsACL adds an acl in the server config
|
||||
func (h *HTTPServer) AddPdnsACL(a *PdnsACL) {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.pdnsAcls = append(h.pdnsAcls, a)
|
||||
}
|
||||
|
||||
// AddjsonRPCACL adds a JRPC acl in the server config
|
||||
func (h *HTTPServer) AddjsonRPCACL(a *JSONRPCACL) {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.jrpcAPIACls = append(h.jrpcAPIACls, a)
|
||||
}
|
||||
|
||||
// getPgpProfiles returns the list of profiles validated by the signed payload,
|
||||
// and the name of the cert
|
||||
func (h *HTTPServer) getPgpProfiles(message, signature []byte) (string, []string) {
|
||||
valid := []string{}
|
||||
signer := ""
|
||||
for name, p := range h.authProfiles {
|
||||
if keyUser := PgpMessageVerify(message, signature, p.PgpKeys()); keyUser != "" {
|
||||
valid = append(valid, name)
|
||||
signer = keyUser
|
||||
}
|
||||
}
|
||||
return signer, valid
|
||||
}
|
||||
|
||||
// getProfiles returns the list of profiles validated for a certificate subject
|
||||
func (h *HTTPServer) getProfiles(subject string) []string {
|
||||
valid := []string{}
|
||||
for name, p := range h.authProfiles {
|
||||
ok, err := p.Match(subject)
|
||||
if err == nil && ok {
|
||||
valid = append(valid, name)
|
||||
}
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// nativeValidAuth validates that a user can access a particular path with a specified method
|
||||
func (h *HTTPServer) nativeValidAuth(path, user, method string) bool {
|
||||
profiles := h.getProfiles(user)
|
||||
|
||||
// no profile validated, no need to continue
|
||||
if len(profiles) == 0 {
|
||||
return false
|
||||
}
|
||||
// check every acl
|
||||
for _, acl := range h.pdnsAcls {
|
||||
if acl.Match(path, method, profiles) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Println("Could not find profile/acl for ", user, method, path)
|
||||
return false
|
||||
}
|
||||
|
||||
// nativeValidAuth validates that a user can access a particular path with a specified method
|
||||
func (h *HTTPServer) jsonrpcValidAuth(j JSONArray, message, signature []byte, certUser string) (bool, string) {
|
||||
// if the DryRun flag is set on all commands, let's authorize it
|
||||
// in debug mode, since we will do nothing
|
||||
if isDryRun(j) && h.debug && len(signature) == 0 {
|
||||
return true, certUser
|
||||
}
|
||||
|
||||
var pgpProfiles, sslProfiles []string
|
||||
|
||||
if len(j) == 0 {
|
||||
// nothing to check
|
||||
return false, ""
|
||||
}
|
||||
// try to get some profile
|
||||
switch {
|
||||
case h.verifyNonce(j[0].Params.Nonce):
|
||||
certUser, pgpProfiles = h.getPgpProfiles(message, signature)
|
||||
case certUser != "":
|
||||
sslProfiles = h.getProfiles(certUser)
|
||||
}
|
||||
// we need at least one profile
|
||||
if len(sslProfiles)+len(pgpProfiles) == 0 {
|
||||
log.Printf("[jsonRPC API] User %s was not authorized to execute anything", certUser)
|
||||
return false, certUser
|
||||
}
|
||||
// get the available acl for the user if there is a search or list
|
||||
listFilters := map[string][]*regexp.Regexp{}
|
||||
for _, acl := range h.jrpcAPIACls {
|
||||
for _, method := range []string{"list", "search"} {
|
||||
listFilters[method] = append(listFilters[method], acl.GetListFilters(pgpProfiles, sslProfiles, method)...)
|
||||
}
|
||||
}
|
||||
// check every acl on every action
|
||||
for i, action := range j {
|
||||
// for list and search, we do the checking after the fact
|
||||
if action.Method == "list" || action.Method == "search" {
|
||||
j[i].listFilters = listFilters[action.Method]
|
||||
continue
|
||||
}
|
||||
// the domain method is always permitted
|
||||
if action.Method == "domain" {
|
||||
continue
|
||||
}
|
||||
ok := false
|
||||
for _, acl := range h.jrpcAPIACls {
|
||||
if acl.Match(action.Method, action.Params.Name, pgpProfiles, sslProfiles) {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
log.Printf("[jsonRPC API] User %s was not authorized to execute \"%s\"", certUser, action.String())
|
||||
return false, certUser
|
||||
}
|
||||
}
|
||||
return true, certUser
|
||||
}
|
||||
|
||||
// Get Certificate user Name from the http.Request
|
||||
func (h *HTTPServer) getCertificate(r *http.Request) (bool, string, error) {
|
||||
// Check if the TLS user certificate is valid
|
||||
var sslUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageAny}
|
||||
opts := x509.VerifyOptions{Roots: h.certPool, KeyUsages: sslUsage}
|
||||
|
||||
if len(r.TLS.PeerCertificates) == 0 {
|
||||
return false, "", nil
|
||||
}
|
||||
if _, err := r.TLS.PeerCertificates[0].Verify(opts); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if err := h.checkCRL(r.TLS.PeerCertificates); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
return true, strings.Replace(r.TLS.PeerCertificates[0].Subject.CommonName, " ", "", -1), nil
|
||||
}
|
||||
|
||||
// decode the http request and return the payload, username and signature
|
||||
func (h *HTTPServer) jrpcDecodeQuery(body []byte) (JSONArray, []byte, []byte, bool, error) {
|
||||
// the payload can be a PGP signed payload
|
||||
// if it's not, message will contain the whole payload
|
||||
b, message := clearsign.Decode(body)
|
||||
// this is not a valid PGP signed payload, meaning message is all we got
|
||||
if b == nil {
|
||||
jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h.dns)
|
||||
return jsonRPC, message, nil, wasArray, err
|
||||
}
|
||||
// this is a valid PGP signed payload, we can extract the real payload and
|
||||
// use the signature to authenticate
|
||||
message = b.Plaintext
|
||||
// there is a newline appended to the payload somehow
|
||||
if message[len(message)-1] == 10 {
|
||||
message = message[:len(message)-1]
|
||||
}
|
||||
// we need the signature to be a []byte and not an Reader, since we
|
||||
// may have to use it several times
|
||||
signature, err := ioutil.ReadAll(b.ArmoredSignature.Body)
|
||||
if err != nil {
|
||||
return JSONArray{}, message, signature, false, err
|
||||
}
|
||||
jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h.dns)
|
||||
return jsonRPC, message, signature, wasArray, err
|
||||
}
|
||||
|
||||
// PowerDNS json RPC API support. Support Cert Auth or PGP signed messages
|
||||
func (h *HTTPServer) jsonRPCServe(w http.ResponseWriter, r *http.Request) {
|
||||
// This API is POST only
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, JSONRPCNewError(-32603, 0, "Internal error").String(), 405)
|
||||
return
|
||||
}
|
||||
w.Header().Set("content-type", "application/json")
|
||||
|
||||
username := ""
|
||||
_, username, certError := h.getCertificate(r)
|
||||
if certError != nil {
|
||||
http.Error(w, JSONRPCNewError(-32008, 0, certError.Error()).String(), 403)
|
||||
return
|
||||
}
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
http.Error(w, JSONRPCNewError(-32603, 0, "Internal error").String(), 500)
|
||||
return
|
||||
}
|
||||
// decode the body
|
||||
jsonRPC, message, signature, wasArray, err := h.jrpcDecodeQuery(body)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
http.Error(w, JSONRPCNewError(-32603, 0, "Internal error").String(), 500)
|
||||
return
|
||||
}
|
||||
// try to validate the query
|
||||
valid, username := h.jsonrpcValidAuth(jsonRPC, message, signature, username)
|
||||
if !valid {
|
||||
http.Error(w, JSONRPCNewError(-32004, 0, "You are not authorized").String(), 403)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, jsonRPC.Run(h, username, wasArray, r.Header.Get("PDNS-Output") == "plaintext"))
|
||||
}
|
||||
|
||||
// PowerDNS native API support. Add certificate support for auth
|
||||
func (h *HTTPServer) nativeAPIServe(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if the user/server is allowed to perform the required action
|
||||
ok, commonName, certError := h.getCertificate(r)
|
||||
if !ok {
|
||||
http.Error(w, certError.Error(), 403)
|
||||
return
|
||||
}
|
||||
log.Println("[Native API] User", commonName, "requested", r.Method, "on", r.RequestURI)
|
||||
if !h.nativeValidAuth(strings.TrimPrefix(r.RequestURI, h.dns.apiURL+"/"), commonName, r.Method) {
|
||||
http.Error(w, "The user "+commonName+" is not authorized to perform this action", 403)
|
||||
log.Println("[Native API] User ", commonName, " was not authorized to perform the action")
|
||||
return
|
||||
}
|
||||
h.dns.Proxy(w, r)
|
||||
}
|
||||
|
||||
// GetZoneConfig returns a configuration for a new zone
|
||||
func (h *HTTPServer) GetZoneConfig(s string) (string, string, []string, JSONArray, bool, error) {
|
||||
valid := ""
|
||||
def := JSONArray{}
|
||||
for zoneType, profile := range h.zoneProfiles {
|
||||
for _, re := range profile.Regexp {
|
||||
if !re.MatchString(s) {
|
||||
continue
|
||||
}
|
||||
if valid != "" {
|
||||
return "", "", nil, nil, false, errors.New("Multiple profiles matched, check the config")
|
||||
}
|
||||
valid = zoneType
|
||||
}
|
||||
}
|
||||
if valid != "" {
|
||||
profile := h.zoneProfiles[valid]
|
||||
for _, e := range profile.DefaultEntries {
|
||||
def = append(def, e)
|
||||
}
|
||||
return valid, profile.SOA, profile.NameServers, def, profile.AutoInc, nil
|
||||
}
|
||||
// check for the default
|
||||
for zoneType, profile := range h.zoneProfiles {
|
||||
if profile.Default {
|
||||
for _, e := range profile.DefaultEntries {
|
||||
def = append(def, e)
|
||||
}
|
||||
return zoneType, profile.SOA, profile.NameServers, def, profile.AutoInc, nil
|
||||
}
|
||||
}
|
||||
return "", "", nil, nil, false, errors.New("no valid configuration found")
|
||||
}
|
||||
|
||||
// RefreshCRL decodes the crl file if it's configured, and store it for later use
|
||||
func (h *HTTPServer) RefreshCRL() (*pkix.CertificateList, error) {
|
||||
// if there is no crl, no need to do anything
|
||||
if h.crl == "" {
|
||||
return nil, nil
|
||||
}
|
||||
t, err := times.Stat(h.crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mtime := t.ModTime()
|
||||
|
||||
if h.decodedCrlTime.Equal(mtime) {
|
||||
return h.decodedCrl, nil
|
||||
}
|
||||
rawCrl, err := ioutil.ReadFile(h.crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decoded, err := x509.ParseCRL(rawCrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ok := false
|
||||
for _, ca := range h.decodedCA {
|
||||
if err := ca.CheckCRLSignature(decoded); err == nil {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("CRL issued with the wrong CA")
|
||||
}
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.decodedCrlTime = mtime
|
||||
h.decodedCrl = decoded
|
||||
return h.decodedCrl, nil
|
||||
}
|
||||
|
||||
func (h *HTTPServer) checkCRL(allCerts []*x509.Certificate) error {
|
||||
crl, err := h.RefreshCRL()
|
||||
if err != nil || crl == nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
expire := now.Add(24 * time.Hour)
|
||||
for _, cert := range allCerts {
|
||||
sign := base64.StdEncoding.EncodeToString(cert.Signature)
|
||||
if t, ok := h.certCache[sign]; ok && t.Before(expire) {
|
||||
continue
|
||||
}
|
||||
for _, revoked := range h.decodedCrl.TBSCertList.RevokedCertificates {
|
||||
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 {
|
||||
return fmt.Errorf("Certificate for %s is revoked", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
h.certCache[sign] = now
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run the server
|
||||
func (h *HTTPServer) Run() {
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Addr: h.Port,
|
||||
Handler: mux,
|
||||
TLSConfig: &tls.Config{
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
ClientCAs: h.certPool,
|
||||
},
|
||||
}
|
||||
|
||||
serverRoot, err := fs.Sub(embeddedFS, "static")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// give acces part of the native API
|
||||
mux.HandleFunc("/api/v1/servers/localhost/", h.nativeAPIServe)
|
||||
// but not all
|
||||
mux.Handle("/api/v1/", http.NotFoundHandler())
|
||||
// new json RPC api
|
||||
mux.HandleFunc("/jsonrpc", h.jsonRPCServe)
|
||||
// needed for security
|
||||
mux.HandleFunc("/nonce", h.sendNonce)
|
||||
// status page
|
||||
mux.HandleFunc("/stats/", h.nativeAPIServe)
|
||||
mux.HandleFunc("/style.css", h.nativeAPIServe)
|
||||
if _, err := os.Stat("./web"); err == nil && h.debug {
|
||||
h.dns.LogDebug("serving local files")
|
||||
mux.Handle("/", http.FileServer(http.Dir("web")))
|
||||
} else {
|
||||
mux.Handle("/", http.FileServer(http.FS(serverRoot)))
|
||||
}
|
||||
log.Println("Ready to serve")
|
||||
log.Fatal(server.ListenAndServeTLS(h.cert, h.key))
|
||||
}
|
|
@ -0,0 +1,699 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
// JSONArray is an array of JSONInput
|
||||
JSONArray []*JSONInput
|
||||
|
||||
// JSONRPCResponse is a jsonRPC response structure
|
||||
JSONRPCResponse struct {
|
||||
ID int `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Result []*JSONRPCResult `json:"result"`
|
||||
}
|
||||
// JSONRPCError is a jsonRPC error structure
|
||||
JSONRPCError struct {
|
||||
ID int `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// JSONInput is a json rpc 2.0 compatible structure for the PDNS API
|
||||
JSONInput struct {
|
||||
Method string `json:"method"`
|
||||
Params JSONInputParams `json:"params"`
|
||||
ID int `json:"id"`
|
||||
ignoreBadDomain bool
|
||||
user string
|
||||
listFilters []*regexp.Regexp
|
||||
}
|
||||
// JSONInputParams encode the parameters of the method
|
||||
JSONInputParams struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
TTL int `json:"ttl"`
|
||||
ForceReverse bool `json:"reverse"`
|
||||
Append bool `json:"append"`
|
||||
Priority int `json:"priority"`
|
||||
Comment string `json:"comment"`
|
||||
DryRun bool `json:"dry-run"`
|
||||
IgnoreError bool `json:"ignore-error"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
// JSONRPCResult is the type of the response of the API
|
||||
JSONRPCResult struct {
|
||||
Changes string `json:"changes"`
|
||||
Comment string `json:"comment"`
|
||||
Result string `json:"result"`
|
||||
Raw []interface{} `json:"raw"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
)
|
||||
|
||||
// JSONRPCResult creates new result
|
||||
func (j JSONInput) JSONRPCResult(content, comment string, err error) *JSONRPCResult {
|
||||
ret := &JSONRPCResult{
|
||||
Changes: content,
|
||||
Comment: comment,
|
||||
Raw: []interface{}{},
|
||||
}
|
||||
if err != nil {
|
||||
ret.Comment = "There was an error"
|
||||
ret.Error = err.Error()
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// IsError returns if this response is an error or not
|
||||
func (r *JSONRPCResult) IsError() bool {
|
||||
return len(r.Error) > 0
|
||||
}
|
||||
|
||||
// JSONInput try to convert back a JSONInput to the original command line
|
||||
func (j JSONInput) String() string {
|
||||
ret := ""
|
||||
if j.Params.TTL != 172800 || j.Method == "ttl" {
|
||||
ret += fmt.Sprintf("-t %d ", j.Params.TTL)
|
||||
}
|
||||
if j.Params.ForceReverse {
|
||||
ret += "-f "
|
||||
}
|
||||
if j.Params.Append {
|
||||
ret += "-a "
|
||||
}
|
||||
if j.Method == "mx" {
|
||||
ret += fmt.Sprintf("-p %d ", j.Params.Priority)
|
||||
}
|
||||
if len(j.Params.Comment) > 0 {
|
||||
ret += fmt.Sprintf("-c \"%s\" ", j.Params.Comment)
|
||||
}
|
||||
if j.Params.DryRun {
|
||||
ret += "-n "
|
||||
}
|
||||
return fmt.Sprintf("%s %s %s %s", ret, j.Method, j.Params.Name, j.Params.Value)
|
||||
}
|
||||
|
||||
// SetDefaults modify the input by setting the right default parameters
|
||||
func (ja JSONArray) SetDefaults(p *PowerDNS) JSONArray {
|
||||
for i, j := range ja {
|
||||
if j.Params.TTL == 0 {
|
||||
ja[i].Params.TTL = p.DefaultTTL
|
||||
}
|
||||
}
|
||||
return ja
|
||||
}
|
||||
|
||||
// Run the actions defined in the json structure
|
||||
func (ja JSONArray) Run(h *HTTPServer, user string, wasArray, textOnly bool) string {
|
||||
listResult := []interface{}{}
|
||||
plain := ""
|
||||
for _, j := range ja {
|
||||
result := j.Run(h, user)
|
||||
for _, line := range result {
|
||||
plain = plain + line.Changes
|
||||
}
|
||||
log.Printf("[jsonRPC API] User %s used command \"jsonrpc %s\"\n", user, j.String())
|
||||
if len(result) == 0 {
|
||||
listResult = append(listResult, JSONRPCResponse{ID: j.ID, JSONRPC: "2.0"})
|
||||
continue
|
||||
}
|
||||
// check if the last entry of the result is an error
|
||||
last := result[len(result)-1]
|
||||
if !last.IsError() {
|
||||
listResult = append(listResult, JSONRPCResponse{ID: j.ID, JSONRPC: "2.0", Result: result})
|
||||
continue
|
||||
}
|
||||
listResult = append(listResult, JSONRPCNewError(-32000, j.ID, last.Error))
|
||||
h.dns.LogDebug(last.Error)
|
||||
if !j.Params.IgnoreError {
|
||||
break
|
||||
}
|
||||
}
|
||||
if textOnly {
|
||||
return plain
|
||||
}
|
||||
if wasArray {
|
||||
return string(printJSON(listResult))
|
||||
}
|
||||
if len(listResult) > 0 {
|
||||
return string(printJSON(listResult[0]))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Run the action defined in the json structure
|
||||
func (j *JSONInput) Run(h *HTTPServer, user string) []*JSONRPCResult {
|
||||
// store the username
|
||||
j.user = user
|
||||
ret := []*JSONRPCResult{}
|
||||
// normalize the query, and do some checks
|
||||
if err := j.Normalize(); err != nil {
|
||||
return append(ret, j.JSONRPCResult("", "", err))
|
||||
}
|
||||
switch j.Method {
|
||||
// list is a spacial case, it doesn't imply a DNSQuery() object
|
||||
case "list":
|
||||
result, err := h.dns.ListZones(j.Params.Name)
|
||||
// we apply the acl after the fact for the list method
|
||||
result = j.FilterList(result)
|
||||
|
||||
if err == nil && len(result) == 0 {
|
||||
err = errors.New("Unknown domain")
|
||||
}
|
||||
res := j.JSONRPCResult(result.List("\n"), "", err)
|
||||
for i := range result {
|
||||
res.Raw = append(res.Raw, result[i].Name)
|
||||
}
|
||||
return append(ret, res)
|
||||
case "domain":
|
||||
parentName, err := h.dns.GetDomain(j.Params.Name)
|
||||
res := j.JSONRPCResult(parentName, "", err)
|
||||
res.Raw = append(res.Raw, parentName)
|
||||
return append(ret, res)
|
||||
}
|
||||
actions, err := j.DNSQueries(h)
|
||||
if err != nil {
|
||||
return append(ret, j.JSONRPCResult("", "", err))
|
||||
}
|
||||
for _, act := range actions {
|
||||
// always add a comment
|
||||
if j.Params.Comment == "" && !j.Params.DryRun {
|
||||
j.Params.Comment = "-"
|
||||
}
|
||||
// add comment, ttl, and cleanup the payload
|
||||
if j.Method != "dump" && j.Method != "search" {
|
||||
act.AddCommentAndTTL(user, j.Params.Comment, j.Params.TTL)
|
||||
}
|
||||
result := j.JSONRPCResult("", strings.Join(act.PlainTexts(), "\n"), nil)
|
||||
result.Changes = act.String()
|
||||
result.Raw = append(result.Raw, act)
|
||||
|
||||
// if we are in dry run mode, stop here
|
||||
if j.Params.DryRun {
|
||||
result.Result = "Dry Run, nothing done"
|
||||
ret = append(ret, result)
|
||||
continue
|
||||
}
|
||||
// no domain, it will failed if executed
|
||||
// can be the output of a zone creation
|
||||
if !act.HasDomain() {
|
||||
ret = append(ret, result)
|
||||
continue
|
||||
}
|
||||
code, _, err := h.dns.Execute(act)
|
||||
switch {
|
||||
case err == nil && code == 204:
|
||||
result.Result = "Command Successfull"
|
||||
case err != nil:
|
||||
result.Error = err.Error()
|
||||
default:
|
||||
result.Error = fmt.Sprintf("The return code was %d", code)
|
||||
}
|
||||
ret = append(ret, result)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Normalize make some case change on the JSONInput and limit the wildcards in
|
||||
// the name
|
||||
func (j *JSONInput) Normalize() error {
|
||||
// normalize Name (lower case, add a final .)
|
||||
j.Params.Name = strings.ToLower(j.Params.Name)
|
||||
j.Params.Name = addPoint(j.Params.Name)
|
||||
if j.Method != "txt" {
|
||||
j.Params.Value = strings.ToLower(j.Params.Value)
|
||||
}
|
||||
switch j.Method {
|
||||
case "list":
|
||||
case "search":
|
||||
case "newzone":
|
||||
if !validName(j.Params.Name, false) {
|
||||
return errors.New("invalid name")
|
||||
}
|
||||
case "ptr":
|
||||
if !validName(j.Params.Name, false) {
|
||||
return errors.New("invalid name")
|
||||
}
|
||||
case "srv":
|
||||
if !validSRVName(j.Params.Name) {
|
||||
return errors.New("invalid name")
|
||||
}
|
||||
case "txt":
|
||||
j.Params.Value = addQuotes(j.Params.Value)
|
||||
case "domain":
|
||||
if !validName(j.Params.Name, false) {
|
||||
return errors.New("invalid name")
|
||||
}
|
||||
default:
|
||||
if !validName(j.Params.Name, true) {
|
||||
return errors.New("invalid name")
|
||||
}
|
||||
}
|
||||
// add a final . to the value
|
||||
for _, m := range []string{"cname", "mx", "dname", "ns", "ptr", "srv"} {
|
||||
if j.Method == m {
|
||||
j.Params.Value = addPoint(j.Params.Value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DNSQueries takes a JSONInput and returns a usable []DNSQuery to be sent to
|
||||
// pdns. It can change the content of j to force dry run mode
|
||||
func (j *JSONInput) DNSQueries(h *HTTPServer) ([]*DNSQuery, error) {
|
||||
var err error
|
||||
switch j.Method {
|
||||
case "search":
|
||||
result, err := h.dns.Search(j.Params.Name)
|
||||
// we apply the acl after the fact for the search method
|
||||
result = j.FilterSearch(result)
|
||||
j.Params.DryRun = true
|
||||
return []*DNSQuery{result.DNSQuery()}, err
|
||||
case "dump":
|
||||
result, err := h.dns.Zone(j.Params.Name)
|
||||
j.Params.DryRun = true
|
||||
return []*DNSQuery{result}, err
|
||||
case "newzone":
|
||||
newZone, otherActions, err := j.NewZone(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if j.Params.DryRun {
|
||||
return append(otherActions, newZone.TransformIntoDNSQuery()), nil
|
||||
}
|
||||
result := &DNSQuery{}
|
||||
code, _, err := h.dns.ExecuteZone(newZone, result)
|
||||
if err == nil && code != 201 {
|
||||
err = fmt.Errorf("The return code was %d for the creation of zone %s", code, j.Params.Name)
|
||||
}
|
||||
return append(otherActions, result), err
|
||||
}
|
||||
current, err := h.dns.GetRecord(j.Params.Name, j.Method, j.ignoreBadDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch j.Method {
|
||||
case "ttl":
|
||||
return j.DNSQueriesTTL(current)
|
||||
case "delete":
|
||||
return j.DNSQueriesDelete(h, current)
|
||||
}
|
||||
// test if there is something to add
|
||||
if current.Useless(j.Params.Name, j.Params.Value, j.Params.Append) {
|
||||
j.Params.DryRun = true
|
||||
current.AddPlainText(0, "Nothing to do, the record is unchanged")
|
||||
return []*DNSQuery{current}, nil
|
||||
}
|
||||
switch j.Method {
|
||||
case "ns":
|
||||
return j.DNSQueriesNS(h, current)
|
||||
case "a":
|
||||
return j.DNSQueriesA(h, current)
|
||||
case "aaaa":
|
||||
return j.DNSQueriesA(h, current)
|
||||
case "cname":
|
||||
if err = j.CheckCNAME(h); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "dname":
|
||||
if err = j.CheckCNAME(h); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "mx":
|
||||
if err = j.CheckMX(h); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "srv":
|
||||
if err = j.CheckSRV(h); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "ptr":
|
||||
if err = j.CheckPTR(h); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "caa":
|
||||
if err = j.CheckCAA(); err == nil {
|
||||
return j.DNSQueriesGeneric(current)
|
||||
}
|
||||
case "txt":
|
||||
return j.DNSQueriesGeneric(current)
|
||||
default:
|
||||
err = errors.New("unknown action")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// DNSQueriesGeneric is the DNSQueries method for the most commands
|
||||
func (j *JSONInput) DNSQueriesGeneric(current *DNSQuery) ([]*DNSQuery, error) {
|
||||
// just change the value
|
||||
current.ChangeValue(j.Params.Name, j.Params.Value, j.Method,
|
||||
current.Len() > 0 && !j.Params.Append, false)
|
||||
return []*DNSQuery{current}, nil
|
||||
}
|
||||
|
||||
// DNSQueriesNS change the NS of a record
|
||||
func (j *JSONInput) DNSQueriesNS(h *HTTPServer, current *DNSQuery) ([]*DNSQuery, error) {
|
||||
if trimPoint(j.Params.Name) == trimPoint(current.Domain) {
|
||||
return nil, fmt.Errorf("You cannot change the NS of a local zone")
|
||||
}
|
||||
subZone, err := h.dns.Zone(j.Params.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentNS := map[string]bool{j.Params.Value: true}
|
||||
currentGlue := map[string]bool{}
|
||||
for _, entry := range subZone.RRSets {
|
||||
if entry.Type == "NS" && entry.Name == j.Params.Name {
|
||||
for _, v := range entry.Records {
|
||||
currentNS[v.Content] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if entry.Type == "A" || entry.Type == "AAAA" {
|
||||
currentGlue[entry.Name] = true
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"There are records that will be masked if you delegate the zone\nPlease delete them first")
|
||||
}
|
||||
fmt.Println(currentNS, currentGlue)
|
||||
for ns := range currentNS {
|
||||
if !strings.HasSuffix(ns, "."+addPoint(j.Params.Name)) {
|
||||
continue
|
||||
}
|
||||
if _, ok := currentGlue[ns]; !ok {
|
||||
return nil, fmt.Errorf("You must first create a glue record to resolve %s", j.Params.Value)
|
||||
}
|
||||
}
|
||||
for glue := range currentGlue {
|
||||
if _, ok := currentNS[glue]; !ok {
|
||||
return nil, fmt.Errorf(
|
||||
"There are records that will be masked if you delegate the zone\nPlease delete them first")
|
||||
}
|
||||
}
|
||||
current.ChangeValue(j.Params.Name, j.Params.Value, j.Method, !j.Params.Append, false)
|
||||
return []*DNSQuery{current}, nil
|
||||
}
|
||||
|
||||
// DNSQueriesTTL change the TTL of a record
|
||||
func (j *JSONInput) DNSQueriesTTL(current *DNSQuery) ([]*DNSQuery, error) {
|
||||
todo := []int{}
|
||||
for i := range current.RRSets {
|
||||
if current.RRSets[i].TTL == j.Params.TTL {
|
||||
continue
|
||||
}
|
||||
if j.Params.Value == "" {
|
||||
todo = append(todo, i)
|
||||
continue
|
||||
}
|
||||
for _, v := range current.RRSets[i].Records {
|
||||
if v.Content == j.Params.Value {
|
||||
todo = append(todo, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(todo) == 0 {
|
||||
j.Params.DryRun = true
|
||||
return nil, fmt.Errorf("Nothing to do, the record is unchanged")
|
||||
}
|
||||
newRRSets := []*DNSRRSet{}
|
||||
for _, i := range todo {
|
||||
newRRSets = append(newRRSets, current.RRSets[i])
|
||||
}
|
||||
current.RRSets = newRRSets
|
||||
return []*DNSQuery{current}, nil
|
||||
}
|
||||
|
||||
// DNSQueriesA is the DNSQueries method for the A and AAAA commands
|
||||
// it checks the command is legal, and can make reverse DNS changes if needed
|
||||
func (j *JSONInput) DNSQueriesA(h *HTTPServer, forward *DNSQuery) ([]*DNSQuery, error) {
|
||||
reverse, err := h.dns.GetReverse(j.Params.Value, j.Method == "a")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If we manage the reverse .arpa zone, we set it if either there is no
|
||||
// reverse, or the forceReverse parameter is set
|
||||
askForReverse := reverse != nil && (j.Params.ForceReverse || reverse.Len() == 0)
|
||||
if askForReverse && j.Params.Name[0] == '*' {
|
||||
return nil, errors.New("Can't set a reverse to a wildcard")
|
||||
}
|
||||
// simple case : the name didn't exist, or we do an append
|
||||
if forward.Len() == 0 || j.Params.Append {
|
||||
forward.ChangeValue(j.Params.Name, j.Params.Value, j.Method, false, askForReverse)
|
||||
return []*DNSQuery{forward}, nil
|
||||
}
|
||||
actions, err := h.dns.ReverseChanges(forward, j.Params.Value)
|
||||
forward.ChangeValue(j.Params.Name, j.Params.Value, j.Method, true, askForReverse)
|
||||
return append(actions, forward), err
|
||||
}
|
||||
|
||||
// DNSQueriesDelete is the DNSQueries method for the Delete command
|
||||
func (j *JSONInput) DNSQueriesDelete(h *HTTPServer, current *DNSQuery) ([]*DNSQuery, error) {
|
||||
reverses, useful, err := current.SplitDeletionQuery(j.Params.Name, j.Params.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !useful {
|
||||
j.Params.DryRun = true
|
||||
current.AddPlainText(0, "Nothing to do, the record is unchanged")
|
||||
return []*DNSQuery{current}, nil
|
||||
}
|
||||
ret := []*DNSQuery{current}
|
||||
// add the reverse changes if needed
|
||||
for _, r := range reverses {
|
||||
parentName, err := h.dns.GetDomain(r.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = append(ret, r.DNSQuery(parentName))
|
||||
if !r.IsDeletion() {
|
||||
continue
|
||||
}
|
||||
ip := ptrToIP(r.Name)
|
||||
if h.dns.IsUsed(ip, j.Params.Name, []string{"A", "AAAA"}) {
|
||||
message := "Reverse issue : %s is the reverse for %s and will be removed\n"
|
||||
message += "But other records are pointing to %s as well. Please cleanup first\n"
|
||||
return ret, fmt.Errorf(message, j.Params.Name, ip, ip)
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// CheckPTR validates that the query is a valid PTR
|
||||
func (j *JSONInput) CheckPTR(h *HTTPServer) error {
|
||||
ip := ptrToIP(j.Params.Name)
|
||||
if ip == "" {
|
||||
return fmt.Errorf("%s is not a valid PTR", j.Params.Name)
|
||||
}
|
||||
target := &DNSQuery{}
|
||||
if err := h.dns.CanCreate(j.Params.Value, true, target); err != nil || target.Len() == 0 {
|
||||
return err
|
||||
}
|
||||
for _, rec := range target.RRSets[0].Records {
|
||||
if rec.Content == ip {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s must point to %s", j.Params.Value, ip)
|
||||
}
|
||||
|
||||
// CheckSRV validates that the query is a valid SRV
|
||||
func (j *JSONInput) CheckSRV(h *HTTPServer) error {
|
||||
name := validSRV(j.Params.Value)
|
||||
if name == "" {
|
||||
return fmt.Errorf("%s is not a valid SRV", j.Params.Value)
|
||||
}
|
||||
return h.dns.CanCreate(name, false, nil)
|
||||
}
|
||||
|
||||
// CheckCAA validates that the query is a valid CAA
|
||||
func (j *JSONInput) CheckCAA() error {
|
||||
v := validCAA(j.Params.Value)
|
||||
if v == "" {
|
||||
return fmt.Errorf("%s is not a valid CAA", j.Params.Value)
|
||||
}
|
||||
j.Params.Value = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckMX validates that the query is a valid MX
|
||||
func (j *JSONInput) CheckMX(h *HTTPServer) error {
|
||||
name := validMX(j.Params.Value)
|
||||
if name == "" {
|
||||
return fmt.Errorf("%s is not a valid MX", j.Params.Value)
|
||||
}
|
||||
return h.dns.CanCreate(name, true, nil)
|
||||
}
|
||||
|
||||
// CheckCNAME validates that the query is a valid CNAME
|
||||
func (j *JSONInput) CheckCNAME(h *HTTPServer) error {
|
||||
test := net.ParseIP(trimPoint(j.Params.Value))
|
||||
if test != nil {
|
||||
return fmt.Errorf("%s is an IP, not a DNS Name", j.Params.Value)
|
||||
}
|
||||
if !validName(j.Params.Value, false) {
|
||||
return fmt.Errorf("%s is not a valid DNS Name", j.Params.Value)
|
||||
}
|
||||
return h.dns.CanCreate(j.Params.Value, false, nil)
|
||||
}
|
||||
|
||||
// NewZone create a new zone in the DNS.
|
||||
// If the new zone is a subzone of an existing one, it will move potential
|
||||
// existing entries into the new zone
|
||||
func (j *JSONInput) NewZone(h *HTTPServer) (z *DNSZone, otherActions []*DNSQuery, err error) {
|
||||
// get the zone parameters
|
||||
zoneType, soa, nameServers, defaults, autoInc, err := h.GetZoneConfig(j.Params.Name)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if soa == "" {
|
||||
soa = fmt.Sprintf("%s hostmaster.%s 0 28800 7200 604800 86400", nameServers[0], j.Params.Name)
|
||||
}
|
||||
parentName, zoneErr := h.dns.GetDomain(j.Params.Name)
|
||||
parentName = trimPoint(parentName)
|
||||
|
||||
z, err = h.dns.NewZone(j.Params.Name, zoneType, soa, j.user, j.Params.Comment,
|
||||
j.Params.TTL, nameServers, autoInc)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// check if there is a parent zone
|
||||
if zoneErr == nil && parentName == trimPoint(j.Params.Name) {
|
||||
return nil, nil, fmt.Errorf("%s already exists", j.Params.Name)
|
||||
}
|
||||
|
||||
// we didn't find a parent zone, we must create the default entries
|
||||
if zoneErr != nil {
|
||||
for _, d := range defaults {
|
||||
// make a copy
|
||||
entry := *d
|
||||
if d.Params.Name == "" {
|
||||
entry.Params.Name = j.Params.Name
|
||||
} else {
|
||||
entry.Params.Name = fmt.Sprintf("%s.%s", d.Params.Name, j.Params.Name)
|
||||
}
|
||||
if err := entry.Normalize(); err != nil {
|
||||
h.dns.LogDebug(err)
|
||||
continue
|
||||
}
|
||||
actions, err := entry.DNSQueries(h)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, a := range actions {
|
||||
a.ChangeDomain(j.Params.Name)
|
||||
a.AddCommentAndTTL(j.user, j.Params.Comment, j.Params.TTL)
|
||||
z.AddEntries(a)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// we must create the NS records in the parent zone too
|
||||
glue, err := h.dns.SetNameServers(j.Params.Name, parentName, nameServers)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
otherActions = append(otherActions, glue)
|
||||
|
||||
// check if there are records in the parent zone
|
||||
records, err := h.dns.Zone(j.Params.Name)
|
||||
if err != nil {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
// if there are records, we must add them in the new zone...
|
||||
records.SetPlainTexts("Add", false)
|
||||
z.AddEntries(records)
|
||||
|
||||
// and delete them in the parent
|
||||
delete, err := h.dns.Zone(j.Params.Name)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
delete.EmptyZone()
|
||||
delete.ChangeDomain(parentName)
|
||||
otherActions = append(otherActions, delete)
|
||||
return
|
||||
}
|
||||
|
||||
// FilterSearch prune the DNSSearch according to the restrictions in j.listFilters
|
||||
func (j *JSONInput) FilterSearch(z []*DNSSearchEntry) []*DNSSearchEntry {
|
||||
if len(z) == 0 {
|
||||
return z
|
||||
}
|
||||
good := []int{}
|
||||
for i := range z {
|
||||
name := trimPoint(z[i].Name)
|
||||
for _, re := range j.listFilters {
|
||||
if re.MatchString(name) {
|
||||
good = append(good, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
newList := []*DNSSearchEntry{}
|
||||
for _, i := range good {
|
||||
newList = append(newList, z[i])
|
||||
}
|
||||
return newList
|
||||
}
|
||||
|
||||
// FilterList prune the DNSZones according to the restrictions in j.listFilters
|
||||
func (j *JSONInput) FilterList(z []*DNSZone) []*DNSZone {
|
||||
if len(z) == 0 {
|
||||
return z
|
||||
}
|
||||
good := []int{}
|
||||
for i := range z {
|
||||
name := trimPoint(z[i].Name)
|
||||
for _, re := range j.listFilters {
|
||||
if re.MatchString(name) {
|
||||
good = append(good, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
newList := []*DNSZone{}
|
||||
for _, i := range good {
|
||||
newList = append(newList, z[i])
|
||||
}
|
||||
return newList
|
||||
}
|
||||
|
||||
// ParsejsonRPCRequest read the payload of the query and put it in the structure
|
||||
func ParsejsonRPCRequest(s []byte, d *PowerDNS) (JSONArray, bool, error) {
|
||||
var inSimple *JSONInput
|
||||
var inArray JSONArray
|
||||
if json.Unmarshal(s, &inArray); len(inArray) > 0 {
|
||||
return inArray.SetDefaults(d), true, nil
|
||||
}
|
||||
if err := json.Unmarshal(s, &inSimple); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return JSONArray{inSimple}.SetDefaults(d), false, nil
|
||||
}
|
||||
|
||||
func isDryRun(j JSONArray) bool {
|
||||
for _, cmd := range j {
|
||||
if !cmd.Params.DryRun {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,189 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newHTTPClient(timeout int, keyFile, certFile string) (*http.Client, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
transport := &http.Transport{TLSClientConfig: tlsConfig}
|
||||
return &http.Client{
|
||||
Timeout: time.Duration(time.Duration(timeout) * time.Second),
|
||||
Transport: transport}, nil
|
||||
}
|
||||
|
||||
func jsonPost(url string, data []byte, client *http.Client) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func deleteZone(url string, client *http.Client) error {
|
||||
req, err := http.NewRequest("DELETE", url, nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if _, err := client.Do(req); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRecording(t *testing.T) {
|
||||
// JSONRPCResponse is a jsonRPC response structure
|
||||
type JSONRPCErrorOrResponse struct {
|
||||
ID int `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Result []*JSONRPCResult `json:"result"`
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
const (
|
||||
configFile = "pdns-proxy-unit-test.conf"
|
||||
commandsFile = "fixtures/replay/commands.asc"
|
||||
replayFile = "fixtures/replay/record"
|
||||
)
|
||||
var orig, rerun []*JSONRPCErrorOrResponse
|
||||
|
||||
badClient, err := newHTTPClient(30, "fixtures/test/client-key.pem", "fixtures/test/badclient-cert.pem")
|
||||
if err != nil {
|
||||
t.Errorf("cannot create bad client : %s", err)
|
||||
return
|
||||
}
|
||||
client, err := newHTTPClient(30, "fixtures/test/client-key.pem", "fixtures/test/client-cert.pem")
|
||||
if err != nil {
|
||||
t.Errorf("cannot create client : %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
h, err := loadConfig(configFile)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// disable nonce security
|
||||
h.nonceGen = ""
|
||||
go h.Run()
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
jsonCommands, err := ioutil.ReadFile(commandsFile)
|
||||
// no command files, no need to continue
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, zone := range []string{
|
||||
"10.in-addr.arpa",
|
||||
"2.0.192.in-addr.arpa",
|
||||
"toto.example.com",
|
||||
"example.com",
|
||||
} {
|
||||
url := fmt.Sprintf("https://127.0.0.1:8443/api/v1/servers/localhost/zones/%s", zone)
|
||||
deleteZone(url, client)
|
||||
}
|
||||
|
||||
result, err := jsonPost("https://127.0.0.1:8443/jsonrpc", []byte("{}"), badClient)
|
||||
if err != nil {
|
||||
t.Errorf("Issue with the server call : %s", err)
|
||||
return
|
||||
}
|
||||
badClientResponse := &JSONRPCErrorOrResponse{}
|
||||
if err := json.Unmarshal(result, badClientResponse); err != nil {
|
||||
t.Errorf("cannot decode response : %s", err)
|
||||
return
|
||||
}
|
||||
if badClientResponse.Error.Message != "Certificate for invalidserver is revoked" {
|
||||
t.Errorf("CRL not working")
|
||||
return
|
||||
}
|
||||
|
||||
result, err = jsonPost("https://127.0.0.1:8443/jsonrpc", jsonCommands, client)
|
||||
if err != nil {
|
||||
t.Errorf("Issue with the server call : %s", err)
|
||||
return
|
||||
}
|
||||
// get the structuve of the query
|
||||
jsonRPC, _, _, _, err := h.jrpcDecodeQuery(jsonCommands)
|
||||
|
||||
recording, err := ioutil.ReadFile(replayFile)
|
||||
// no command files, let's try to save the result for the next time
|
||||
if err != nil {
|
||||
if err := ioutil.WriteFile(replayFile, result, 0644); err != nil {
|
||||
t.Errorf("cannot write replay file")
|
||||
}
|
||||
return
|
||||
}
|
||||
// let's compare the 2 runs
|
||||
if err := json.Unmarshal(recording, &orig); err != nil {
|
||||
t.Errorf("cannot read replay file : %s", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(result, &rerun); err != nil {
|
||||
t.Errorf("cannot read result file : %s", err)
|
||||
return
|
||||
}
|
||||
last := len(orig) - 1
|
||||
for i := range orig {
|
||||
if orig[i].Error != rerun[i].Error {
|
||||
t.Errorf("the command \"pdns %s\" (line %d) has a different error message\n>>>>>>\n%s\n<<<<<<<\n%s",
|
||||
jsonRPC[i], i, rerun[i].Error.Message, orig[i].Error.Message)
|
||||
return
|
||||
}
|
||||
if rerun[i].Error.Message != "" {
|
||||
log.Printf("\"dmdns %s\" : %s\n", jsonRPC[i], rerun[i].Error.Message)
|
||||
}
|
||||
|
||||
if len(orig[i].Result) != len(rerun[i].Result) {
|
||||
t.Errorf("the command \"pdns %s\" (line %d) has a different result length", jsonRPC[i], i)
|
||||
return
|
||||
}
|
||||
for j := range orig[i].Result {
|
||||
if orig[i].Result[j].Comment != rerun[i].Result[j].Comment {
|
||||
t.Errorf("the command \"pdns %s\" (line %d) has a different error message\n>>>>>>\n%s\n<<<<<<<\n%s",
|
||||
jsonRPC[i], i, rerun[i].Result[j].Comment, orig[i].Result[j].Comment)
|
||||
return
|
||||
}
|
||||
if orig[i].Result[j].Result != rerun[i].Result[j].Result {
|
||||
t.Errorf("the command \"pdns %s\" (line %d) has a different result %d:\n>>>>>>\n%s\n<<<<<<<\n%s",
|
||||
jsonRPC[i], i, j, rerun[i].Result[j].Result, orig[i].Result[j].Result)
|
||||
return
|
||||
}
|
||||
if orig[i].Result[j].Changes != rerun[i].Result[j].Changes {
|
||||
t.Errorf("the command \"pdns %s\" (line %d) has a different result %d:\n>>>>>>\n%s\n<<<<<<<\n%s",
|
||||
jsonRPC[i], i, j, rerun[i].Result[j].Changes, orig[i].Result[j].Changes)
|
||||
return
|
||||
}
|
||||
log.Printf("%s : %s\n", rerun[i].Result[j].Comment, orig[i].Result[j].Result)
|
||||
if i == last {
|
||||
log.Println("\n" + rerun[i].Result[j].Changes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,272 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
// LdapHandler structure
|
||||
// Contains a pool of connection, and all information needed for authentication
|
||||
type LdapHandler struct {
|
||||
servers []string
|
||||
baseDN string
|
||||
bindCn string
|
||||
bindPw string
|
||||
searchFilter string
|
||||
attribute string
|
||||
pgpAttribute string
|
||||
validValues []string
|
||||
currentServer int
|
||||
ssl bool
|
||||
authCache *AuthCache
|
||||
clients chan LdapClient
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
// LdapClient interface definition
|
||||
type LdapClient interface {
|
||||
ldap.Client
|
||||
}
|
||||
|
||||
// NewLdap initializes a new LDAP structure
|
||||
// We start with an empty pool, no need to prepare connections
|
||||
func NewLdap(servers []string, bindCn, bindPw, baseDN, filter, attr, pgpAttr string, valid []string, nbConn int, ssl bool) *LdapHandler {
|
||||
ldap.DefaultTimeout = 3 * time.Second
|
||||
l := LdapHandler{
|
||||
servers: servers,
|
||||
baseDN: baseDN,
|
||||
bindCn: bindCn,
|
||||
bindPw: bindPw,
|
||||
searchFilter: filter,
|
||||
attribute: attr,
|
||||
pgpAttribute: pgpAttr,
|
||||
validValues: valid,
|
||||
clients: make(chan LdapClient, nbConn),
|
||||
authCache: NewAuthCache(),
|
||||
ssl: ssl,
|
||||
}
|
||||
return &l
|
||||
}
|
||||
|
||||
// PgpKeys get all knonw GPG keys in the directory that belong to entry
|
||||
// matching the the Auth Profile
|
||||
func (l *LdapHandler) PgpKeys() ([]string, error) {
|
||||
// try the cache first
|
||||
ret := l.authCache.PgpGet()
|
||||
if len(ret) > 0 {
|
||||
return ret, nil
|
||||
}
|
||||
// get a conn from the pool
|
||||
conn, err := l.GetConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// and put it back
|
||||
defer l.BackToPool(conn)
|
||||
|
||||
// the search parameters
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
l.baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=pgpUserKeyInfo)",
|
||||
[]string{l.attribute, l.pgpAttribute},
|
||||
nil,
|
||||
)
|
||||
sr, err := conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, entry := range sr.Entries {
|
||||
key := ""
|
||||
valid := false
|
||||
for _, attribute := range entry.Attributes {
|
||||
switch (*attribute).Name {
|
||||
case l.pgpAttribute:
|
||||
key = attribute.Values[0]
|
||||
case l.attribute:
|
||||
if inArray(l.validValues, attribute.Values) {
|
||||
valid = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if valid && key != "" {
|
||||
ret = append(ret, l.PgpNormalizeKey(key))
|
||||
}
|
||||
}
|
||||
if len(ret) > 0 {
|
||||
l.authCache.PgpSet(ret)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Auth method fot LDAP
|
||||
// We check that the subject exists in the ldap
|
||||
// If it's the case, we search for the attribute defined in the LdapHandler
|
||||
// This attribute's value must then be one of the registered value in the LdapHandler
|
||||
func (l *LdapHandler) Auth(subject string) (bool, error) {
|
||||
// use the cache if possible
|
||||
if l.authCache.Get(subject) {
|
||||
return true, nil
|
||||
}
|
||||
// get a conn from the pool
|
||||
conn, err := l.GetConn()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// and put it back
|
||||
defer l.BackToPool(conn)
|
||||
|
||||
// the search parameters
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
l.baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf(l.searchFilter, subject),
|
||||
[]string{l.attribute},
|
||||
nil,
|
||||
)
|
||||
sr, err := conn.Search(searchRequest)
|
||||
|
||||
// if the search failed or returns more than one entry, it's not valid
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(sr.Entries) != 1 {
|
||||
return false, errors.New("User does not exist or too many entries returned")
|
||||
}
|
||||
// validate the values returned
|
||||
for _, attribute := range sr.Entries[0].Attributes {
|
||||
if (*attribute).Name == l.attribute {
|
||||
if inArray(l.validValues, attribute.Values) {
|
||||
l.authCache.Set(subject)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// we found nothing, auth failed
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// NewConn creates a ldap client object
|
||||
func (l *LdapHandler) NewConn() (LdapClient, error) {
|
||||
var err error
|
||||
var conn *ldap.Conn
|
||||
// the server list cannot be empty
|
||||
if len(l.servers) == 0 {
|
||||
return nil, errors.New("Empty server list")
|
||||
}
|
||||
// circle the server list
|
||||
for i := 0; i < len(l.servers); i++ {
|
||||
l.currentServer++
|
||||
if l.currentServer >= len(l.servers) {
|
||||
l.currentServer = 0
|
||||
}
|
||||
if l.ssl {
|
||||
conn, err = ldap.DialTLS("tcp",
|
||||
fmt.Sprintf("%s:%d", l.servers[l.currentServer], 636),
|
||||
&tls.Config{ServerName: l.servers[l.currentServer]})
|
||||
} else {
|
||||
conn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d",
|
||||
l.servers[l.currentServer], 389))
|
||||
}
|
||||
// if the connection fails, maybe there is another server available
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
// First bind with a read only user
|
||||
// if it work, we are done
|
||||
if err = conn.Bind(l.bindCn, l.bindPw); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
// no working server were found
|
||||
return nil, errors.New("No valid ldap server found")
|
||||
}
|
||||
|
||||
// GetConn retrieves a new connection from the pool
|
||||
func (l *LdapHandler) GetConn() (LdapClient, error) {
|
||||
// check if the pool is valid
|
||||
if l.clients == nil {
|
||||
return nil, errors.New("Pool is closed")
|
||||
}
|
||||
select {
|
||||
case conn := <-l.clients:
|
||||
if conn == nil {
|
||||
return nil, errors.New("Pool is closed")
|
||||
}
|
||||
if ldapIsAlive(conn) {
|
||||
return conn, nil
|
||||
}
|
||||
// dead connection, restart it
|
||||
conn.Close()
|
||||
return l.NewConn()
|
||||
default:
|
||||
// No more conn in Pool, create a new one and return it
|
||||
return l.NewConn()
|
||||
}
|
||||
}
|
||||
|
||||
// BackToPool returns a connection to the pool
|
||||
func (l *LdapHandler) BackToPool(p LdapClient) error {
|
||||
// if it's nil, stop here
|
||||
if p == nil {
|
||||
return errors.New("Connexion is closed")
|
||||
}
|
||||
// check if it's alive. If not, no need to put it back
|
||||
if !ldapIsAlive(p) {
|
||||
p.Close()
|
||||
return errors.New("returned connection was closed")
|
||||
}
|
||||
// same if the pool is not active
|
||||
if l.clients == nil {
|
||||
p.Close()
|
||||
return errors.New("Pool is closed")
|
||||
}
|
||||
select {
|
||||
case l.clients <- p:
|
||||
return nil
|
||||
default:
|
||||
// Pool is full
|
||||
p.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// make a dummy request to validate that the server is alive
|
||||
func ldapIsAlive(client LdapClient) bool {
|
||||
_, err := client.Search(&ldap.SearchRequest{BaseDN: "", Scope: ldap.ScopeBaseObject, Filter: "(&)", Attributes: []string{"1.1"}})
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Len returns information about the pool
|
||||
func (l *LdapHandler) Len() int {
|
||||
return len(l.clients)
|
||||
}
|
||||
|
||||
// Cap returns cap information about the pool
|
||||
func (l *LdapHandler) Cap() int {
|
||||
return cap(l.clients)
|
||||
}
|
||||
|
||||
// PgpNormalizeKey replace space with new line in ldap stored GPG keys
|
||||
func (l *LdapHandler) PgpNormalizeKey(key string) string {
|
||||
var rkey, sep string
|
||||
for _, sub := range strings.Split(key, " ") {
|
||||
rkey = rkey + sep + sub
|
||||
|
||||
if strings.HasPrefix(sub, "---") {
|
||||
sep = " "
|
||||
}
|
||||
if strings.HasSuffix(sub, "---") {
|
||||
sep = "\n"
|
||||
}
|
||||
}
|
||||
return rkey
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func hasLdapProfile() (bool, *LdapHandler) {
|
||||
for _, profile := range fakeserver.authProfiles {
|
||||
if p, ok := profile.(AuthProfileLdap); ok {
|
||||
return true, p.ldap
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func TestLdapAuth(t *testing.T) {
|
||||
if ok, _ := hasLdapProfile(); !ok {
|
||||
return
|
||||
}
|
||||
// optionnal test, need a ldap configuration in localtest.conf
|
||||
v := fakeserver.getProfiles("xavier@example.org")
|
||||
if len(v) != 1 {
|
||||
t.Errorf("cannot valid test profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLdapPool(t *testing.T) {
|
||||
conns := []LdapClient{}
|
||||
|
||||
ok, l := hasLdapProfile()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cap := l.Cap()
|
||||
|
||||
// use all connexions in the pool, with one bonus
|
||||
for i := 0; i < cap+1; i++ {
|
||||
conn, err := l.GetConn()
|
||||
if err != nil {
|
||||
t.Errorf("Get error: %s", err)
|
||||
}
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
// return the originally pooled connexions to the pool
|
||||
for i := 0; i < cap; i++ {
|
||||
if err := l.BackToPool(conns[i]); err != nil {
|
||||
t.Errorf("BackToPool error: %s", err)
|
||||
}
|
||||
if nbConn := l.Len(); nbConn != i+1 {
|
||||
t.Errorf("Pool didn't return used conn %d", nbConn)
|
||||
}
|
||||
// even if is was returned, it's still there
|
||||
if !ldapIsAlive(conns[i]) {
|
||||
t.Errorf("This connextion should not be closed")
|
||||
}
|
||||
}
|
||||
// return the last one
|
||||
if err := l.BackToPool(conns[cap]); err != nil {
|
||||
t.Errorf("BackToPool error: %s", err)
|
||||
}
|
||||
if nbConn := l.Len(); nbConn != cap {
|
||||
t.Errorf("The pool should be full")
|
||||
}
|
||||
// check it was closed
|
||||
if ldapIsAlive(conns[cap]) {
|
||||
t.Errorf("This connexion should be closed")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"log/syslog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
|
||||
// default configuration file is /etc/pdns-proxy.conf
|
||||
configFile := flag.String("config", "/etc/pdns-proxy.conf", "configuration file")
|
||||
logToSyslog := flag.Bool("syslog", false, "Log to syslog")
|
||||
debug := flag.Bool("debug", false, "log every message received")
|
||||
flag.Parse()
|
||||
|
||||
// in the production environnement, redirect every lgo() call to syslog
|
||||
if *logToSyslog {
|
||||
log.SetFlags(0)
|
||||
logWriter, e := syslog.New(syslog.LOG_NOTICE, "pdnsProxy")
|
||||
if e == nil {
|
||||
log.SetOutput(logWriter)
|
||||
defer logWriter.Close()
|
||||
}
|
||||
}
|
||||
h, err := loadConfig(*configFile)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *debug {
|
||||
h.Debug()
|
||||
}
|
||||
h.Run()
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
config
|
||||
{
|
||||
profiles
|
||||
{
|
||||
{{< localtest.conf }}
|
||||
testS2S
|
||||
{
|
||||
subjectRegexp: "validserver"
|
||||
type: regexp
|
||||
}
|
||||
}
|
||||
pdnsAcls
|
||||
{
|
||||
"testS2S"
|
||||
{
|
||||
regexp: "zones/dev\\..*"
|
||||
perms: ["r"]
|
||||
profiles: [ "testS2S" ]
|
||||
},
|
||||
"admin"
|
||||
{
|
||||
regexp: ".*"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "infra" ]
|
||||
},
|
||||
"writeTest"
|
||||
{
|
||||
regexp: "zones/specificdomain.example"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "testS2S" ]
|
||||
},
|
||||
}
|
||||
jrpcAcls
|
||||
{
|
||||
"admin"
|
||||
{
|
||||
perms
|
||||
{
|
||||
"*": [ ".*" ]
|
||||
}
|
||||
pgpProfiles: [ "infra" ]
|
||||
},
|
||||
"testS2S"
|
||||
{
|
||||
perms
|
||||
{
|
||||
"*": [ ".*toto.example.com" ]
|
||||
"list": [ ".*" ]
|
||||
"search" [ ".*example.com" ]
|
||||
}
|
||||
sslProfiles: [ "testS2S" ]
|
||||
},
|
||||
"webui":
|
||||
{
|
||||
perms
|
||||
{
|
||||
"*": [ ".*corp.*" ]
|
||||
}
|
||||
sslProfiles: [ "infra" ]
|
||||
}
|
||||
"security"
|
||||
{
|
||||
perms
|
||||
{
|
||||
"dump": [ ".*" ]
|
||||
"list": [ ".*" ]
|
||||
"search": [ ".*" ]
|
||||
}
|
||||
sslProfiles: [ "security"]
|
||||
}
|
||||
}
|
||||
http
|
||||
{
|
||||
port: ":8443"
|
||||
ca: "fixtures/test/ca.crt"
|
||||
key: "fixtures/test/server-key.pem"
|
||||
cert: "fixtures/test/server-cert.pem"
|
||||
}
|
||||
pdns
|
||||
{
|
||||
api-key: "123password"
|
||||
api-url: "http://127.0.0.1:8081/api/v1/servers/localhost"
|
||||
timeout: 300
|
||||
defaultTTL: 172800
|
||||
|
||||
|
||||
}
|
||||
zoneProfile
|
||||
{
|
||||
Native
|
||||
{
|
||||
nameservers: [ "a.iana-servers.net.", "b.iana-servers.net." ]
|
||||
default: false
|
||||
soa: "ns.icann.org. noc.dns.icann.org. 0 28800 7200 604800 86400"
|
||||
whenRegexp
|
||||
[
|
||||
"(^|.*[^.]\\.)10\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)168\\.192\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)(1[6-9]|2[0-9]|3[0-1])\\.172\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)(6[4-9]|[7-9][0-9]|1([0-1][0-9]|2[0-7]))\\.100\\.in-addr\\.arpa",
|
||||
]
|
||||
}
|
||||
Master
|
||||
{
|
||||
nameservers: [ "a.iana-servers.net.", "b.iana-servers.net." ]
|
||||
default: true
|
||||
soa: "ns.icann.org. noc.dns.icann.org. 0 28800 7200 604800 86400"
|
||||
populate
|
||||
{
|
||||
spf
|
||||
{
|
||||
name: ""
|
||||
type: "txt"
|
||||
value: "v=spf1 -all"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
config
|
||||
{
|
||||
profiles:
|
||||
{
|
||||
infra:
|
||||
{
|
||||
subjectRegexp: ".*@example.org"
|
||||
type: "ldap"
|
||||
servers: [ "ldap" ]
|
||||
bindCn: "cn=admin,dc=example,dc=org"
|
||||
bindPw: "admin"
|
||||
baseDN: "ou=users,dc=example,dc=org"
|
||||
searchFilter: "(&(mail=%s))"
|
||||
attribute: "description"
|
||||
pgpAttribute: "pgpKey"
|
||||
ssl: false
|
||||
validValues: [ "infra", "vwf" ]
|
||||
}
|
||||
testS2S:
|
||||
{
|
||||
subjectRegexp: "validserver"
|
||||
type: regexp
|
||||
pgpKeys: "{{<fixtures/test/public-key.txt}}"
|
||||
}
|
||||
}
|
||||
pdnsAcls:
|
||||
{
|
||||
"testS2S":
|
||||
{
|
||||
regexp: ".*"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "testS2S" ]
|
||||
},
|
||||
"admin":
|
||||
{
|
||||
regexp: ".*"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "infra" ]
|
||||
},
|
||||
"writeTest":
|
||||
{
|
||||
regexp: "zones/specificdomain.example"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "testS2S" ]
|
||||
},
|
||||
}
|
||||
jrpcAcls:
|
||||
{
|
||||
"admin"
|
||||
{
|
||||
perms
|
||||
{
|
||||
"*": [ ".*" ]
|
||||
}
|
||||
pgpProfiles: [ "testS2S" ]
|
||||
},
|
||||
}
|
||||
http:
|
||||
{
|
||||
port: "127.0.0.1:8443"
|
||||
ca: "fixtures/test/ca.crt"
|
||||
key: "fixtures/test/server-key.pem"
|
||||
cert: "fixtures/test/server-cert.pem"
|
||||
crl: "fixtures/test/root.crl.pem"
|
||||
}
|
||||
pdns:
|
||||
{
|
||||
api-key: "123password"
|
||||
api-url: "http://127.0.0.1:8081/api/v1/servers/localhost"
|
||||
defaultTTL: 172800
|
||||
}
|
||||
zoneProfile:
|
||||
{
|
||||
Native:
|
||||
{
|
||||
nameservers: [ "a.example.org.", "b.example.org." ]
|
||||
default: true
|
||||
autoIncrement: false
|
||||
soa: "a.example.org. admin.example.com. 0 10380 3600 604800 3600"
|
||||
populate
|
||||
{
|
||||
spf
|
||||
{
|
||||
name: ""
|
||||
type: "txt"
|
||||
value: "v=spf1 -all"
|
||||
}
|
||||
}
|
||||
}
|
||||
Master:
|
||||
{
|
||||
nameservers: [ "private-01.example.org.", "private-02.example.org." ]
|
||||
soa: "private-01.example.org. admin.priv.example.com. 0 10380 3600 604800 3600"
|
||||
autoIncrement: false
|
||||
whenRegexp
|
||||
[
|
||||
"(^|.*[^.]\\.)10\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)168\\.192\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)(1[6-9]|2[0-9]|3[0-1])\\.172\\.in-addr\\.arpa",
|
||||
"(^|.*[^.]\\.)(6[4-9]|[7-9][0-9]|1([0-1][0-9]|2[0-7]))\\.100\\.in-addr\\.arpa",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
config
|
||||
{
|
||||
profiles:
|
||||
{
|
||||
infra:
|
||||
{
|
||||
subjectRegexp: ".*@example.org"
|
||||
type: ldap
|
||||
servers: [ "ldap.example.org" ]
|
||||
bindCn: "cn=readonly,dc=example,dc=org"
|
||||
bindPw: "**********"
|
||||
baseDN: "ou=users,dc=example,dc=org"
|
||||
searchFilter: "(&(mail=%s))"
|
||||
attribute: "description"
|
||||
pgpAttribute: "pgpKey"
|
||||
validValues: [ "infra", "vwf" ]
|
||||
}
|
||||
devServer:
|
||||
{
|
||||
subjectRegexp: "[a-z0-9-]*\\.dev\\.[a-z0-9]*\\.example.org"
|
||||
type: regexp
|
||||
}
|
||||
letsencrypt:
|
||||
{
|
||||
subjectRegexp: "probe-[0-9]*\\.adm\\.dc3\\.example.org"
|
||||
type: regexp
|
||||
}
|
||||
icscale:
|
||||
{
|
||||
subjectRegexp: "icscale-[0-9]*\\.adm\\.[a-z0-9]*\\.example.org"
|
||||
type: regexp
|
||||
}
|
||||
}
|
||||
pdnsAcls:
|
||||
{
|
||||
"dev":
|
||||
{
|
||||
regexp: "zones/dev\\.[a-z0-9]*\\.example.org"
|
||||
perms: ["r"]
|
||||
profiles: [ "devServer", "devUsers" ]
|
||||
},
|
||||
"letsencrypt":
|
||||
{
|
||||
regexp: "zones/.*"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "letsencrypt" ]
|
||||
},
|
||||
"infra":
|
||||
{
|
||||
regexp: ".*"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "infra" ]
|
||||
},
|
||||
"scale":
|
||||
{
|
||||
regexp: "zones/kube.dm.gg"
|
||||
perms: ["r", "w"]
|
||||
profiles: [ "icscale" ]
|
||||
},
|
||||
}
|
||||
jrpcAcls:
|
||||
{
|
||||
}
|
||||
http:
|
||||
{
|
||||
port: ":443"
|
||||
ca: "/usr/local/share/ca-certificates/ca.crt"
|
||||
key: "/etc/ssl/private/server-key.pem"
|
||||
cert: "/etc/ssl/certs/server-bundle.pem"
|
||||
}
|
||||
pdns:
|
||||
{
|
||||
api-key: "<pdns_api_key>"
|
||||
api-url: "http://127.0.0.1:8081/api/v1/servers/localhost"
|
||||
}
|
||||
zoneProfile:
|
||||
{
|
||||
private:
|
||||
{
|
||||
nameservers: [ "a.example.org", "b.example.org" ]
|
||||
zoneType: "MASTER"
|
||||
}
|
||||
public:
|
||||
{
|
||||
nameservers: [ "a.iana-servers.net", "b.iana-servers.net" ]
|
||||
zoneType: "NATIVE"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,320 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PowerDNS defines the client to which queries are passed
|
||||
type (
|
||||
PowerDNS struct {
|
||||
Scheme string
|
||||
Hostname string
|
||||
Port string
|
||||
apiURL string
|
||||
debug bool
|
||||
Client *http.Client
|
||||
m sync.RWMutex
|
||||
listCache map[string]*listCache
|
||||
DefaultTTL int
|
||||
}
|
||||
)
|
||||
|
||||
// NewClient initializes a new PowerDNS client configuration
|
||||
func NewClient(baseURL, apiKey string, timeout, ttl int) (*PowerDNS, error) {
|
||||
scheme, hostname, port, path, err := parseBaseURL(baseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("%s is not a valid URL: %v", baseURL, err)
|
||||
}
|
||||
transport := http.DefaultTransport
|
||||
apiKeyTransport := &APIKeyTransport{
|
||||
Transport: transport,
|
||||
APIKey: apiKey,
|
||||
}
|
||||
errorTransport := &ErrorTransport{
|
||||
Transport: apiKeyTransport,
|
||||
}
|
||||
powerDNS := &PowerDNS{
|
||||
Scheme: scheme,
|
||||
Hostname: hostname,
|
||||
Port: port,
|
||||
Client: &http.Client{
|
||||
Transport: errorTransport,
|
||||
Timeout: time.Duration(timeout) * time.Second,
|
||||
},
|
||||
apiURL: path,
|
||||
listCache: map[string]*listCache{},
|
||||
DefaultTTL: ttl,
|
||||
}
|
||||
return powerDNS, nil
|
||||
}
|
||||
|
||||
// Debug toggle the debug mode
|
||||
func (p *PowerDNS) Debug() {
|
||||
p.debug = true
|
||||
}
|
||||
|
||||
// LogDebug facility
|
||||
func (p *PowerDNS) LogDebug(v ...interface{}) {
|
||||
if p.debug {
|
||||
log.Println(v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Lock the structure
|
||||
func (p *PowerDNS) lock() {
|
||||
p.m.Lock()
|
||||
}
|
||||
|
||||
// ...or unlock it
|
||||
func (p *PowerDNS) unlock() {
|
||||
p.m.Unlock()
|
||||
}
|
||||
|
||||
// Ping tries to contact a powerDNS URL API to make sure its up and accessible
|
||||
func (p *PowerDNS) Ping(ctx context.Context) error {
|
||||
u := url.URL{}
|
||||
u.Host = p.Hostname + ":" + p.Port
|
||||
u.Scheme = p.Scheme
|
||||
u.Path = "/api"
|
||||
req, err := http.NewRequest("GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := p.Client.Do(req.WithContext(ctx))
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute send the DNSQuery structure to the PDNS api
|
||||
func (p *PowerDNS) Execute(d *DNSQuery) (int, http.Header, error) {
|
||||
uri := fmt.Sprintf("%s/%s/%s", p.apiURL, "zones", trimPoint(d.Domain))
|
||||
return p.sendQuery(context.Background(), uri, "PATCH", printJSON(d), nil)
|
||||
}
|
||||
|
||||
// ExecuteZone send the DNSZone structure to the PDNS api
|
||||
func (p *PowerDNS) ExecuteZone(d *DNSZone, resp *DNSQuery) (int, http.Header, error) {
|
||||
uri := fmt.Sprintf("%s/%s", p.apiURL, "zones")
|
||||
// purge the zone cache
|
||||
p.listCache = map[string]*listCache{}
|
||||
return p.sendQuery(context.Background(), uri, "POST", printJSON(d), resp)
|
||||
}
|
||||
|
||||
// Zone return the whole zone if it exist
|
||||
func (p *PowerDNS) Zone(name string) (*DNSQuery, error) {
|
||||
var result DNSQuery
|
||||
|
||||
parentName, err := p.GetDomain(name)
|
||||
if err != nil {
|
||||
return nil, errors.New("Unknown domain")
|
||||
}
|
||||
code, _, err := p.sendQuery(context.Background(),
|
||||
fmt.Sprintf("%s/%s/%s", p.apiURL, "zones", parentName), "GET", nil, &result)
|
||||
if code != 200 {
|
||||
return nil, errors.New("Unknown zone, but it shouldn't")
|
||||
}
|
||||
// if this is a subzone, filter the result
|
||||
if strings.HasSuffix(name, "."+addPoint(parentName)) {
|
||||
good := []int{}
|
||||
filter := []*DNSRRSet{}
|
||||
sub := "." + name
|
||||
for i := range result.RRSets {
|
||||
if result.RRSets[i].Name == name || strings.HasSuffix(result.RRSets[i].Name, sub) {
|
||||
good = append(good, i)
|
||||
}
|
||||
}
|
||||
for _, i := range good {
|
||||
filter = append(filter, result.RRSets[i])
|
||||
}
|
||||
result.RRSets = filter
|
||||
}
|
||||
if err == nil {
|
||||
result.Domain = name
|
||||
}
|
||||
return &result, err
|
||||
}
|
||||
|
||||
// Search perform a search in the pdns API
|
||||
func (p *PowerDNS) Search(r string) (DNSSearch, error) {
|
||||
var records DNSSearch
|
||||
// remove any final point to the record if necessary
|
||||
url := fmt.Sprintf("%s/%s?max=5000&q=%s", p.apiURL, "search-data", trimPoint(r))
|
||||
|
||||
if _, _, err := p.sendQuery(context.Background(), url, "GET", nil, &records); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// IsUsed search the DB to check if anyting is pointing to name
|
||||
func (p *PowerDNS) IsUsed(name, exception string, rtype []string) bool {
|
||||
exception = strings.ToLower(exception)
|
||||
search, err := p.Search(name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, rec := range search {
|
||||
if !rec.IsRecord() || rec.Name == exception {
|
||||
continue
|
||||
}
|
||||
for _, t := range rtype {
|
||||
if t == rec.Type {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ReverseChanges returns the list of changes needed if we move the value of
|
||||
// the record d which pointed to excludeIP
|
||||
func (p *PowerDNS) ReverseChanges(d *DNSQuery, excludeIP string) ([]*DNSQuery, error) {
|
||||
actions := []*DNSQuery{}
|
||||
for _, entry := range d.RRSets {
|
||||
for _, record := range entry.Records {
|
||||
ip := record.Content
|
||||
// ignore the record pointing to excludeIP
|
||||
if ip == excludeIP {
|
||||
continue
|
||||
}
|
||||
reverse, _ := p.GetReverse(ip, entry.Type == "A")
|
||||
// no reverse (inconstancy or external IP): no problem here
|
||||
if reverse == nil || reverse.Len() == 0 {
|
||||
continue
|
||||
}
|
||||
// this reverse doesn't point back to entry.Name, no action needed
|
||||
if !reverse.RemoveValue(entry.Name) {
|
||||
continue
|
||||
}
|
||||
// if we don't remove the reverse, no problem
|
||||
if !reverse.RRSets[0].IsDeletion() {
|
||||
continue
|
||||
}
|
||||
// there was only one PTR, and it's pointing back to j.Params.Name, we
|
||||
// need to remove it
|
||||
if p.IsUsed(ip, entry.Name, []string{entry.Type}) {
|
||||
message := "Reverse issue : %s is the reverse for %s and will be changed\n"
|
||||
message += "But other records are pointing to %s as well. Please cleanup first\n"
|
||||
return actions, fmt.Errorf(message, entry.Name, ip, ip)
|
||||
}
|
||||
actions = append(actions, reverse)
|
||||
}
|
||||
}
|
||||
return actions, nil
|
||||
}
|
||||
|
||||
// getRecord get the PDNS record
|
||||
func (p *PowerDNS) getRecord(r string, wanted, notWanted []string, ignoreBadDomain bool) (*DNSQuery, error) {
|
||||
// Get the domain
|
||||
d, err := p.GetDomain(r)
|
||||
|
||||
// we don't manage the domain, no need to continue (ignoreBadDomain is used
|
||||
// for newly created domains only)
|
||||
if err != nil && !ignoreBadDomain {
|
||||
return nil, err
|
||||
}
|
||||
// get the records, in search format
|
||||
records, err := p.Search(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Now we convert the search format to the DNSQuery format
|
||||
query := records.DNSQuery()
|
||||
query.Domain = d
|
||||
|
||||
// and we do some cleanup
|
||||
err = query.RecordFilter(r, wanted, notWanted)
|
||||
|
||||
return query, err
|
||||
}
|
||||
|
||||
// CanCreate return an error if t can be created on the local DNS and doesn't
|
||||
// exist. An external record, existing or not, is fine
|
||||
func (p *PowerDNS) CanCreate(r string, directOnly bool, target *DNSQuery) error {
|
||||
search, err := p.getRecord(r, []string{"*"}, []string{}, false)
|
||||
// copy the actual structure
|
||||
if target != nil && search != nil {
|
||||
*target = *search
|
||||
}
|
||||
switch {
|
||||
case err == nil:
|
||||
break
|
||||
case err.Error() == "Unknown domain":
|
||||
return nil
|
||||
case strings.HasSuffix(err.Error(), "delegated to another server"):
|
||||
return nil
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
if directOnly {
|
||||
if err := search.RecordFilter(r,
|
||||
[]string{"A", "AAAA"},
|
||||
[]string{"CNAME", "DNAME"}); err != nil {
|
||||
return fmt.Errorf("%s cannot be a CNAME", r)
|
||||
}
|
||||
}
|
||||
if search.Len() == 0 {
|
||||
return fmt.Errorf("%s doesn't exist, create it first", r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRecord get the PDNS record of type t
|
||||
func (p *PowerDNS) GetRecord(r, t string, ignoreBadDomain bool) (*DNSQuery, error) {
|
||||
notWanted := []string{}
|
||||
wanted := []string{strings.ToUpper(t)}
|
||||
switch t {
|
||||
case "ttl":
|
||||
wanted = []string{"*"}
|
||||
notWanted = []string{}
|
||||
case "delete":
|
||||
wanted = []string{"*", "indirect"}
|
||||
case "cname":
|
||||
notWanted = []string{"*"}
|
||||
case "dname":
|
||||
notWanted = []string{"*"}
|
||||
case "ns":
|
||||
wanted = []string{"NS"}
|
||||
notWanted = []string{"*"}
|
||||
default:
|
||||
notWanted = []string{"CNAME", "DNAME"}
|
||||
}
|
||||
return p.getRecord(r, wanted, notWanted, ignoreBadDomain)
|
||||
}
|
||||
|
||||
// GetReverse get the reverse DNS record of the IP
|
||||
func (p *PowerDNS) GetReverse(ipString string, v4 bool) (*DNSQuery, error) {
|
||||
ip := net.ParseIP(ipString)
|
||||
if ip == nil || (v4 != (ip.To4() != nil)) {
|
||||
return nil, errors.New("not a valid IP address")
|
||||
}
|
||||
// we ignore the error on the actual query, which most certainly means
|
||||
// there is no reverse or that we don't manage it
|
||||
ret, _ := p.getRecord(iPtoReverse(ip), []string{"PTR"}, nil, false)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// SetNameServers returns a *DNSQuery to create a set of nameservers in a zone
|
||||
// or its parent
|
||||
func (p *PowerDNS) SetNameServers(zone, where string, servers []string) (*DNSQuery, error) {
|
||||
ns, err := p.GetRecord(zone, "NS", true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, nameServer := range servers {
|
||||
ns.ChangeValue(zone, nameServer, "NS", false, false)
|
||||
}
|
||||
ns.ChangeDomain(where)
|
||||
return ns, nil
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
// APIKeyTransport definition
|
||||
APIKeyTransport struct {
|
||||
Transport http.RoundTripper
|
||||
APIKey string
|
||||
}
|
||||
|
||||
// Handle the answer from the PowerDNS API
|
||||
httpStatusError struct {
|
||||
Response *http.Response
|
||||
Body []byte // Copied from `Response.Body` to avoid problems with unclosed bodies later.
|
||||
}
|
||||
|
||||
// ErrorTransport defines the data structure for returning errors from the round tripper
|
||||
ErrorTransport struct {
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
)
|
||||
|
||||
// RoundTrip implements http.RoundTripper for ApiKeyTransport (RoundTrip method)
|
||||
func (t *APIKeyTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
request.Header.Set("X-API-Key", fmt.Sprintf("%s", t.APIKey))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
return t.Transport.RoundTrip(request)
|
||||
}
|
||||
|
||||
func (err *httpStatusError) Error() string {
|
||||
return fmt.Sprintf("http: non-successful response (status=%v body=%q)", err.Response.StatusCode, err.Body)
|
||||
}
|
||||
|
||||
// RoundTrip is the round tripper for the error transport
|
||||
func (t *ErrorTransport) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
resp, err := t.Transport.RoundTrip(request)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusUnauthorized {
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http: failed to read response body (status=%v, err=%q)", resp.StatusCode, err)
|
||||
}
|
||||
return nil, &httpStatusError{
|
||||
Response: resp,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Proxy is a HTTP Handler which will send the queries directly to PDNS
|
||||
func (p *PowerDNS) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
// ctx is the Context for this handler.
|
||||
var (
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Thanks to context, we could do some req checks before submitting it to powerDNS client
|
||||
// Maybe later
|
||||
|
||||
var response []byte
|
||||
// Get body if provided
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Cannot read body", 400)
|
||||
return
|
||||
}
|
||||
if r.ContentLength == 0 {
|
||||
body = nil
|
||||
}
|
||||
uri := r.RequestURI
|
||||
stats := false
|
||||
if strings.HasPrefix(uri, "/stats/") {
|
||||
stats = true
|
||||
uri = strings.Replace(uri, "/stats/", "/", -1)
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
_, head, err := p.sendQuery(ctx, uri, r.Method, body, &response)
|
||||
if err == nil {
|
||||
// copy the headers
|
||||
for name := range head {
|
||||
if stats && name == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
w.Header().Set(name, head.Get(name))
|
||||
}
|
||||
stringRet := string(response)
|
||||
if stats {
|
||||
stringRet = strings.Replace(stringRet, `href="/"`, `href="/stats/"`, -1)
|
||||
}
|
||||
fmt.Fprintf(w, "%s", stringRet)
|
||||
return
|
||||
}
|
||||
urlErr, ok := err.(*url.Error)
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "Response is: %v", err)
|
||||
return
|
||||
}
|
||||
httpErr, ok := urlErr.Err.(*httpStatusError)
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "Response is: %v", urlErr)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "Response is: %v", httpErr)
|
||||
}
|
||||
|
||||
// sendQuery returns the anwser of a request made to the API
|
||||
func (p *PowerDNS) sendQuery(ctx context.Context, endpoint, method string, body []byte, response interface{}) (int, http.Header, error) {
|
||||
var req *http.Request
|
||||
var err error
|
||||
u := url.URL{}
|
||||
u.Host = p.Hostname + ":" + p.Port
|
||||
u.Scheme = p.Scheme
|
||||
if !strings.HasPrefix(endpoint, "/") {
|
||||
u.Path = "/"
|
||||
}
|
||||
uri := u.String() + endpoint
|
||||
|
||||
p.LogDebug(fmt.Sprintf("pdns.client %s", uri))
|
||||
if body != nil {
|
||||
p.LogDebug(string(body))
|
||||
}
|
||||
if body != nil {
|
||||
req, err = http.NewRequest(method, uri, bytes.NewReader(body))
|
||||
} else {
|
||||
req, err = http.NewRequest(method, uri, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return 500, nil, err
|
||||
}
|
||||
resp, err := p.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return 500, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
p.LogDebug(fmt.Sprintf(" resp.Status=%s", resp.Status))
|
||||
raw, _ := ioutil.ReadAll(resp.Body)
|
||||
|
||||
if response == nil {
|
||||
return resp.StatusCode, resp.Header, nil
|
||||
}
|
||||
if err := json.Unmarshal(raw, response); err == nil {
|
||||
return resp.StatusCode, resp.Header, nil
|
||||
}
|
||||
if ar, ok := response.(*[]byte); ok {
|
||||
*ar = raw
|
||||
}
|
||||
p.LogDebug(string(raw))
|
||||
return resp.StatusCode, resp.Header, nil
|
||||
}
|
||||
|
||||
// parseBaseURL decompose an URL into parts
|
||||
func parseBaseURL(baseURL string) (string, string, string, string, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
hp := strings.Split(u.Host, ":")
|
||||
hostname := hp[0]
|
||||
var port string
|
||||
if len(hp) > 1 {
|
||||
port = hp[1]
|
||||
} else {
|
||||
if u.Scheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "8081"
|
||||
}
|
||||
}
|
||||
return u.Scheme, hostname, port, u.Path, nil
|
||||
}
|
|
@ -0,0 +1,467 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
// DNSQuery is the json structure of the get/add/replace actions in the PDNS API
|
||||
DNSQuery struct {
|
||||
RRSets []*DNSRRSet `json:"rrsets"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
// DNSRRSet is the json structure of a DNS entry in the PDNS API
|
||||
DNSRRSet struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Comment []*DNSComments `json:"comments"`
|
||||
Records []*DNSRecord `json:"records"`
|
||||
ChangeType string `json:"changetype"`
|
||||
TTL int `json:"ttl"`
|
||||
PlainText string `json:"-"`
|
||||
}
|
||||
// DNSComments is the json structure of a comment in the PDNS API
|
||||
DNSComments struct {
|
||||
Account string `json:"account"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
// DNSRecord is the json structure of a record in the PDNS API
|
||||
DNSRecord struct {
|
||||
Content string `json:"content"`
|
||||
Disabled bool `json:"disabled"`
|
||||
SetPTR bool `json:"set-ptr"`
|
||||
}
|
||||
)
|
||||
|
||||
// IsDeletion tells if the nth record of the query is a DELETE or not
|
||||
func (d *DNSQuery) IsDeletion(n int) bool {
|
||||
if d.RecordSize(n) < n {
|
||||
return false
|
||||
}
|
||||
return d.RRSets[0].IsDeletion()
|
||||
}
|
||||
|
||||
// ChangeDomain sets a new name for the DNSQuery
|
||||
func (d *DNSQuery) ChangeDomain(s string) {
|
||||
d.Domain = s
|
||||
}
|
||||
|
||||
// HasDomain tell of the query as the Domain variable empty or not
|
||||
func (d *DNSQuery) HasDomain() bool {
|
||||
return len(d.Domain) != 0
|
||||
}
|
||||
|
||||
// SetPlainTexts add relevent PlainTexts to the RRSets
|
||||
func (d *DNSQuery) SetPlainTexts(action string, reverse bool) {
|
||||
for n := range d.RRSets {
|
||||
i := d.RecordSize(n) - 1
|
||||
if i < 0 {
|
||||
return
|
||||
}
|
||||
if d.RRSets[n].ChangeType == "" {
|
||||
d.RRSets[n].ChangeType = "REPLACE"
|
||||
}
|
||||
comment := fmt.Sprintf("%s %s record %s to %s",
|
||||
action, d.RRSets[n].Type, d.RRSets[n].Name, d.RRSets[n].Records[i].Content)
|
||||
if reverse {
|
||||
comment += " and add the reverse"
|
||||
}
|
||||
d.AddPlainText(n, comment)
|
||||
}
|
||||
}
|
||||
|
||||
// ChangeValue adds a new value to the first record of the query, and creates it
|
||||
// if necessary
|
||||
func (d *DNSQuery) ChangeValue(name, value, rtype string, overwrite, reverse bool) {
|
||||
rtype = strings.ToUpper(rtype)
|
||||
name = strings.ToLower(name)
|
||||
if rtype != "TXT" {
|
||||
value = strings.ToLower(value)
|
||||
}
|
||||
if rtype != "A" && rtype != "AAAA" {
|
||||
reverse = false
|
||||
}
|
||||
action := "Append a new"
|
||||
switch {
|
||||
case len(d.RRSets) == 0:
|
||||
action = "Create a new"
|
||||
case overwrite:
|
||||
action = "Modify"
|
||||
}
|
||||
if overwrite || d.Len() == 0 {
|
||||
d.RRSets = []*DNSRRSet{{
|
||||
Name: name,
|
||||
Type: rtype,
|
||||
Records: []*DNSRecord{},
|
||||
ChangeType: "REPLACE",
|
||||
}}
|
||||
}
|
||||
d.RRSets[0].Records = append(d.RRSets[0].Records, &DNSRecord{
|
||||
Content: value,
|
||||
Disabled: false,
|
||||
SetPTR: reverse,
|
||||
})
|
||||
d.SetPlainTexts(action, reverse)
|
||||
}
|
||||
|
||||
// RecordFilter cleans a []*DNSRRSet by keeping only wanted elements
|
||||
func (d *DNSQuery) RecordFilter(r string, wanted, notWanted []string) error {
|
||||
good := []int{}
|
||||
// Add a final point to the record if necessary
|
||||
r = addPoint(r)
|
||||
for i, record := range d.RRSets {
|
||||
switch {
|
||||
// ignore reverse & co by default
|
||||
case record.Name != r && !stringInSlice("indirect", wanted):
|
||||
continue
|
||||
// check for forbidden types
|
||||
case stringInSlice(record.Type, notWanted):
|
||||
return fmt.Errorf("This entry is a %s, you must delete it first", record.Type)
|
||||
// this is the wanted type
|
||||
case stringInSlice(record.Type, wanted):
|
||||
good = append(good, i)
|
||||
// this is too
|
||||
case stringInSlice("*", wanted):
|
||||
good = append(good, i)
|
||||
// in certain case, we don't want any entry type
|
||||
case stringInSlice("*", notWanted):
|
||||
return fmt.Errorf("This entry is a %s, you must delete it first", record.Type)
|
||||
}
|
||||
}
|
||||
ret := []*DNSRRSet{}
|
||||
for _, i := range good {
|
||||
ret = append(ret, d.RRSets[i])
|
||||
}
|
||||
d.RRSets = ret
|
||||
return nil
|
||||
}
|
||||
|
||||
// Useless check if adding a new name/value couple change anything
|
||||
func (d *DNSQuery) Useless(name, value string, append bool) bool {
|
||||
// record empty, no issue
|
||||
if len(d.RRSets) == 0 {
|
||||
return false
|
||||
}
|
||||
alreadySet := false
|
||||
for _, record := range d.RRSets[0].Records {
|
||||
if record.Content == value {
|
||||
alreadySet = true
|
||||
}
|
||||
}
|
||||
return alreadySet && (len(d.RRSets[0].Records) == 1 || append)
|
||||
}
|
||||
|
||||
// RemoveValue remove a value from the first RRSet
|
||||
func (d *DNSQuery) RemoveValue(value string) bool {
|
||||
return d.RRSets[0].RemoveValue(value)
|
||||
}
|
||||
|
||||
// EmptyZone marks all lines for deletion
|
||||
func (d *DNSQuery) EmptyZone() {
|
||||
for i := range d.RRSets {
|
||||
d.RRSets[i].Delete()
|
||||
}
|
||||
}
|
||||
|
||||
func (d DNSQuery) Len() int {
|
||||
return len(d.RRSets)
|
||||
}
|
||||
|
||||
func (d DNSQuery) Swap(i, j int) {
|
||||
d.RRSets[i], d.RRSets[j] = d.RRSets[j], d.RRSets[i]
|
||||
}
|
||||
|
||||
func (d DNSQuery) Less(i, j int) bool {
|
||||
if d.RRSets[i].Type != d.RRSets[j].Type {
|
||||
if d.RRSets[i].Type == "A" {
|
||||
return false
|
||||
}
|
||||
if d.RRSets[j].Type == "A" {
|
||||
return true
|
||||
}
|
||||
return d.RRSets[i].Type < d.RRSets[j].Type
|
||||
}
|
||||
if len(d.RRSets[i].Records) == 0 || len(d.RRSets[j].Records) == 0 {
|
||||
return d.RRSets[i].Name < d.RRSets[j].Name
|
||||
}
|
||||
if d.RRSets[i].Type == d.RRSets[j].Type {
|
||||
if d.RRSets[i].Records[0].Content == d.RRSets[j].Records[0].Content {
|
||||
return d.RRSets[i].Name < d.RRSets[j].Name
|
||||
}
|
||||
switch d.RRSets[i].Type {
|
||||
case "PTR":
|
||||
return ipToInt(ptrToIP(d.RRSets[i].Name)).Cmp(ipToInt(ptrToIP(d.RRSets[j].Name))) < 0
|
||||
case "CNAME":
|
||||
return d.RRSets[i].Records[0].Content < d.RRSets[j].Records[0].Content
|
||||
case "A":
|
||||
return ipToInt(d.RRSets[i].Records[0].Content).Cmp(ipToInt(d.RRSets[j].Records[0].Content)) < 0
|
||||
case "AAAA":
|
||||
return ipToInt(d.RRSets[i].Records[0].Content).Cmp(ipToInt(d.RRSets[j].Records[0].Content)) < 0
|
||||
}
|
||||
}
|
||||
return reverse(d.RRSets[i].Name) < reverse(d.RRSets[j].Name)
|
||||
}
|
||||
|
||||
// RecordSize returns the number of records in the nth RRSets of the query
|
||||
func (d DNSQuery) RecordSize(n int) int {
|
||||
if d.Len() < n {
|
||||
return 0
|
||||
}
|
||||
return len(d.RRSets[n].Records)
|
||||
}
|
||||
|
||||
// String() Convert a DNSQuery to a Bind like string
|
||||
func (d DNSQuery) String() string {
|
||||
if len(d.RRSets) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Sort(d)
|
||||
result := ""
|
||||
padName := 40
|
||||
padTTL := 6
|
||||
padType := 5
|
||||
padContent := 0
|
||||
|
||||
for _, line := range d.RRSets {
|
||||
lName := len(line.Name)
|
||||
if line.IsDeletion() {
|
||||
lName += 3
|
||||
}
|
||||
if lName-len(d.Domain) > padName {
|
||||
padName = len(line.Name) - len(d.Domain)
|
||||
}
|
||||
if len(strconv.Itoa(line.TTL)) > padTTL {
|
||||
padTTL = len(strconv.Itoa(line.TTL))
|
||||
}
|
||||
if len(line.Type) > padType {
|
||||
padType = len(line.Type)
|
||||
}
|
||||
for _, record := range line.Records {
|
||||
if record.Disabled {
|
||||
continue
|
||||
}
|
||||
if line.Type != "SOA" && len(record.Content) > padContent {
|
||||
padContent = len(record.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
pattern := fmt.Sprintf("%%-%ds %%-%dd IN %%-%ds %%s%%s", padName, padTTL, padType)
|
||||
for _, line := range d.RRSets {
|
||||
comment := ""
|
||||
for _, c := range line.Comment {
|
||||
if c.Account != "" {
|
||||
c.Account = " - " + c.Account
|
||||
}
|
||||
comment += fmt.Sprintf(" ; %s%s", c.Content, c.Account)
|
||||
}
|
||||
sort.Sort(line)
|
||||
name := line.Name
|
||||
|
||||
if line.IsDeletion() {
|
||||
name = fmt.Sprintf(";; %s", name)
|
||||
}
|
||||
idx := strings.LastIndex(name, "."+d.Domain)
|
||||
if idx > 0 {
|
||||
name = name[:idx]
|
||||
}
|
||||
if name == d.Domain {
|
||||
name = "@"
|
||||
}
|
||||
if len(line.Records) == 0 {
|
||||
newLine := fmt.Sprintf(pattern, name, line.TTL, line.Type, "", "")
|
||||
result += strings.Trim(newLine, " ") + "\n"
|
||||
}
|
||||
for _, record := range line.Records {
|
||||
if record.Disabled {
|
||||
continue
|
||||
}
|
||||
newLine := fmt.Sprintf(pattern, name, line.TTL, line.Type, record.Content, comment)
|
||||
result += strings.Trim(newLine, " ") + "\n"
|
||||
}
|
||||
}
|
||||
if d.Domain != "" {
|
||||
result = fmt.Sprintf("$ORIGIN %s\n%s", d.Domain, result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AddPlainText set the PlainText of the nth record of the query
|
||||
func (d *DNSQuery) AddPlainText(n int, s string) {
|
||||
if d.Len() < n {
|
||||
return
|
||||
}
|
||||
d.RRSets[n].AddPlainText(s)
|
||||
}
|
||||
|
||||
// PlainTexts returns the list of plaintext of each record
|
||||
func (d *DNSQuery) PlainTexts() []string {
|
||||
ret := []string{}
|
||||
for record := range d.RRSets {
|
||||
if d.RRSets[record].PlainText == "" {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, d.RRSets[record].PlainText)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddCommentAndTTL adjust the DNSQuery before execution
|
||||
func (d *DNSQuery) AddCommentAndTTL(user, content string, ttl int) {
|
||||
// don't touch if the query is not an executable one
|
||||
if !d.HasDomain() {
|
||||
return
|
||||
}
|
||||
for record := range d.RRSets {
|
||||
if d.RRSets[record].IsDeletion() {
|
||||
d.RRSets[record].Comment = nil
|
||||
continue
|
||||
}
|
||||
d.RRSets[record].TTL = ttl
|
||||
// don't change the comments if you don't have to
|
||||
if len(content) > 0 {
|
||||
d.RRSets[record].Comment = []*DNSComments{{
|
||||
Account: user,
|
||||
Content: content,
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SplitDeletionQuery split the the result of a GetRecord() for a deletion
|
||||
// change the DNSQuery RRSets to be valid for the deletion, and return the
|
||||
// list of affected reverses
|
||||
func (d *DNSQuery) SplitDeletionQuery(name, value string) ([]*DNSRRSet, bool, error) {
|
||||
var remain, indirection bool
|
||||
filter := []*DNSRRSet{}
|
||||
actions := []*DNSRRSet{}
|
||||
reverses := []*DNSRRSet{}
|
||||
for _, entry := range d.RRSets {
|
||||
switch {
|
||||
// protect the zone NS
|
||||
case entry.Type == "NS" && d.Domain+"." == entry.Name:
|
||||
filter = append(filter, entry)
|
||||
// protect the SOA
|
||||
case entry.Type == "SOA":
|
||||
filter = append(filter, entry)
|
||||
// simple case : remove every entries
|
||||
case entry.Name == name && value == "":
|
||||
entry.Delete()
|
||||
actions = append(actions, entry)
|
||||
// a value was specified, need to remove the exact value
|
||||
case entry.Name == name:
|
||||
filter = append(filter, entry)
|
||||
if entry.RemoveValue(value) {
|
||||
actions = append(actions, entry)
|
||||
// we keep some entries
|
||||
remain = remain || entry.Len() > 0
|
||||
}
|
||||
// if we remove a wildcard, no need to check the reverse and
|
||||
// indirections
|
||||
case name[0] == '*':
|
||||
continue
|
||||
case entry.Type != "PTR":
|
||||
indirection = true
|
||||
// entry.Type has now to be PTR
|
||||
case value == "":
|
||||
entry.RemoveValue(name)
|
||||
reverses = append(reverses, entry)
|
||||
case value == ptrToIP(entry.Name):
|
||||
entry.RemoveValue(name)
|
||||
reverses = append(reverses, entry)
|
||||
}
|
||||
}
|
||||
// test if there is something to do
|
||||
if len(actions)+len(filter) == 0 {
|
||||
return nil, false, fmt.Errorf("Unknown entry")
|
||||
}
|
||||
if len(actions) == 0 {
|
||||
d.RRSets = filter // display only the direct entries, not the indirect ones
|
||||
return nil, false, nil
|
||||
}
|
||||
if indirection && !remain {
|
||||
return nil, false, fmt.Errorf("there are records pointing to %s, please delete them first", name)
|
||||
}
|
||||
d.RRSets = actions
|
||||
return reverses, true, nil
|
||||
}
|
||||
|
||||
func (d *DNSRRSet) Len() int {
|
||||
return len(d.Records)
|
||||
}
|
||||
|
||||
func (d *DNSRRSet) Swap(i, j int) {
|
||||
d.Records[i], d.Records[j] = d.Records[j], d.Records[i]
|
||||
}
|
||||
|
||||
func (d *DNSRRSet) Less(i, j int) bool {
|
||||
cmp := ipToInt(d.Records[i].Content).Cmp(ipToInt(d.Records[j].Content))
|
||||
if cmp == 0 {
|
||||
return d.Records[i].Content < d.Records[j].Content
|
||||
}
|
||||
return cmp < 0
|
||||
}
|
||||
|
||||
func (d *DNSRRSet) String() string {
|
||||
return string(printJSON(d))
|
||||
}
|
||||
|
||||
// RemoveValue remove a value from a RRSet
|
||||
func (d *DNSRRSet) RemoveValue(value string) bool {
|
||||
newRecords := []*DNSRecord{}
|
||||
for _, r := range d.Records {
|
||||
switch r.Content {
|
||||
case value:
|
||||
continue
|
||||
case addPoint(value):
|
||||
continue
|
||||
case fmt.Sprintf("\"%s\"", value):
|
||||
continue
|
||||
default:
|
||||
newRecords = append(newRecords, r)
|
||||
}
|
||||
}
|
||||
// no record removed
|
||||
if len(newRecords) == d.Len() {
|
||||
return false
|
||||
}
|
||||
// no record left, change the RRSet ChangeType to DELETE
|
||||
if len(newRecords) == 0 {
|
||||
d.Delete()
|
||||
return true
|
||||
}
|
||||
d.PlainText = fmt.Sprintf("Update %s by removing %s", d.Name, value)
|
||||
d.Records = newRecords
|
||||
return true
|
||||
}
|
||||
|
||||
// DNSQuery return a DNSQuery containing the current DNSRRSet
|
||||
func (d *DNSRRSet) DNSQuery(domain string) *DNSQuery {
|
||||
return &DNSQuery{RRSets: []*DNSRRSet{d}, Domain: domain}
|
||||
}
|
||||
|
||||
// AddPlainText set the PlainText of the rrset
|
||||
func (d *DNSRRSet) AddPlainText(s string) {
|
||||
d.PlainText = s
|
||||
}
|
||||
|
||||
// Delete change tye RRSet ChangeType to DELETE
|
||||
func (d *DNSRRSet) Delete() {
|
||||
d.ChangeType = "DELETE"
|
||||
d.PlainText = fmt.Sprintf("Removing %s %s", d.Type, d.Name)
|
||||
}
|
||||
|
||||
// IsDeletion tells if d is a DELETE or not
|
||||
func (d *DNSRRSet) IsDeletion() bool {
|
||||
return d.ChangeType == "DELETE"
|
||||
}
|
||||
|
||||
func (d DNSRecord) String() string {
|
||||
return string(printJSON(d))
|
||||
}
|
||||
|
||||
func (d DNSComments) String() string {
|
||||
return string(printJSON(d))
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package main
|
||||
|
||||
type (
|
||||
// DNSSearch is the json structure of a search query response
|
||||
DNSSearch []*DNSSearchEntry
|
||||
|
||||
// DNSSearchEntry is an element of a DNSSearch
|
||||
DNSSearchEntry struct {
|
||||
Content string `json:"content"`
|
||||
Disabled bool `json:"disabled"`
|
||||
Name string `json:"name"`
|
||||
ObjectType string `json:"object_type"`
|
||||
TTL int `json:"ttl"`
|
||||
Type string `json:"type"`
|
||||
Zone string `json:"zone"`
|
||||
ZoneID string `json:"zone_id"`
|
||||
}
|
||||
)
|
||||
|
||||
// IsRecord tells if d is a record or not
|
||||
func (d DNSSearchEntry) IsRecord() bool {
|
||||
return d.ObjectType == "record"
|
||||
}
|
||||
|
||||
// IsComment tells if d is a comment or not
|
||||
func (d DNSSearchEntry) IsComment() bool {
|
||||
return d.ObjectType == "comment"
|
||||
}
|
||||
|
||||
// DNSQuery converts a DNSSearch to a DNSQuery
|
||||
func (d DNSSearch) DNSQuery() *DNSQuery {
|
||||
tempRet := map[string]map[string]*DNSRRSet{}
|
||||
tempContent := map[string]map[string][]*DNSRecord{}
|
||||
comments := map[string][]*DNSComments{}
|
||||
|
||||
for _, record := range d {
|
||||
switch {
|
||||
// ignore disabled records
|
||||
case record.Disabled:
|
||||
continue
|
||||
// store the comments
|
||||
case record.IsComment():
|
||||
comments[record.Name] = []*DNSComments{{Content: record.Content}}
|
||||
continue
|
||||
// ignore non records
|
||||
case !record.IsRecord():
|
||||
continue
|
||||
}
|
||||
// check if we already encounter this record
|
||||
if _, ok := tempRet[record.Name]; !ok {
|
||||
tempRet[record.Name] = map[string]*DNSRRSet{}
|
||||
tempContent[record.Name] = map[string][]*DNSRecord{}
|
||||
}
|
||||
// store the entry
|
||||
if _, ok := tempRet[record.Name][record.Type]; !ok {
|
||||
tempRet[record.Name][record.Type] = &DNSRRSet{
|
||||
Name: record.Name,
|
||||
Type: record.Type,
|
||||
ChangeType: "REPLACE",
|
||||
TTL: record.TTL,
|
||||
}
|
||||
tempContent[record.Name][record.Type] = []*DNSRecord{}
|
||||
}
|
||||
// and store the content
|
||||
tempContent[record.Name][record.Type] = append(tempContent[record.Name][record.Type], &DNSRecord{
|
||||
Content: record.Content,
|
||||
Disabled: false,
|
||||
SetPTR: false,
|
||||
})
|
||||
}
|
||||
// stitch everything together
|
||||
retRRSet := []*DNSRRSet{}
|
||||
|
||||
for rName, recordsByName := range tempRet {
|
||||
for rType, record := range recordsByName {
|
||||
if c, ok := comments[rName]; ok {
|
||||
record.Comment = c
|
||||
}
|
||||
if content, ok := tempContent[rName][rType]; ok {
|
||||
record.Records = content
|
||||
}
|
||||
retRRSet = append(retRRSet, record)
|
||||
}
|
||||
}
|
||||
return &DNSQuery{RRSets: retRRSet}
|
||||
}
|
||||
|
||||
func (d *DNSSearch) String() string {
|
||||
return string(printJSON(d))
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/jarcoal/httpmock"
|
||||
)
|
||||
|
||||
var (
|
||||
pingBody = `{
|
||||
"url": "/api/v1",
|
||||
"version": 1
|
||||
}`
|
||||
getZonesBody = `[{
|
||||
"account": "",
|
||||
"dnssec": false,
|
||||
"id": "example.com.",
|
||||
"kind": "Native",
|
||||
"last_check": 0,
|
||||
"masters": [],
|
||||
"name": "example.com.",
|
||||
"notified_serial": 0,
|
||||
"serial": 1,
|
||||
"url": "/api/v1/servers/localhost/zones/example.com."},
|
||||
{"account": "",
|
||||
"dnssec": false,
|
||||
"id": "example2.net.",
|
||||
"kind": "Native",
|
||||
"last_check": 0,
|
||||
"masters": [],
|
||||
"name": "example2.net.",
|
||||
"notified_serial": 0,
|
||||
"serial": 2019091801,
|
||||
"url": "/api/v1/servers/localhost/zones/example2.net."
|
||||
}]`
|
||||
badPatchBody = `{
|
||||
"error": "DNS Name 'email.example.com' is not canonical"
|
||||
}`
|
||||
addMasterZonePayload = `{
|
||||
"name":"example.org",
|
||||
"kind": "Master",
|
||||
"dnssec":false,
|
||||
"soa-edit":"INCEPTION-INCREMENT",
|
||||
"masters": [],
|
||||
"nameservers": ["ns1.example.org"]
|
||||
}`
|
||||
)
|
||||
|
||||
func TestPdnsPing(t *testing.T) {
|
||||
httpmock.Activate()
|
||||
defer httpmock.DeactivateAndReset()
|
||||
httpmock.RegisterResponder("GET", testBaseURL+"/api",
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("X-Api-Key") == testAPIKey {
|
||||
return httpmock.NewJsonResponse(201, pingBody)
|
||||
}
|
||||
return httpmock.NewStringResponse(401, "Unauthorized"), nil
|
||||
},
|
||||
)
|
||||
var response json.RawMessage
|
||||
testClient, _ := initializePowerDNSTestClient()
|
||||
_, _, err := testClient.sendQuery(context.Background(), "api", "GET", nil, &response)
|
||||
if err != nil {
|
||||
t.Errorf("%s", err)
|
||||
}
|
||||
isEqual, _ := areEqualJSON([]byte(strconv.Quote(pingBody)), response)
|
||||
if !isEqual {
|
||||
t.Error("Test Pdns Ping - Failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPdnsGetZones(t *testing.T) {
|
||||
httpmock.Activate()
|
||||
defer httpmock.DeactivateAndReset()
|
||||
httpmock.RegisterResponder("GET", generateTestAPIVhostURL()+"/zones",
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("X-Api-Key") == testAPIKey {
|
||||
return httpmock.NewJsonResponse(201, getZonesBody)
|
||||
}
|
||||
return httpmock.NewStringResponse(401, "Unauthorized"), nil
|
||||
},
|
||||
)
|
||||
var response json.RawMessage
|
||||
testClient, _ := initializePowerDNSTestClient()
|
||||
_, _, err := testClient.sendQuery(context.Background(), generateTestRequestURI()+"/zones", "GET", nil, &response)
|
||||
if err != nil {
|
||||
t.Errorf("%s", err)
|
||||
}
|
||||
isEqual, _ := areEqualJSON([]byte(strconv.Quote(getZonesBody)), response)
|
||||
if !isEqual {
|
||||
t.Error("Test Pdns getZones - Failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPdnsBadGet(t *testing.T) {
|
||||
httpmock.Activate()
|
||||
defer httpmock.DeactivateAndReset()
|
||||
httpmock.RegisterResponder("GET", generateTestAPIVhostURL()+"/zone",
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("X-Api-Key") == testAPIKey {
|
||||
return httpmock.NewJsonResponse(404, "Not found")
|
||||
}
|
||||
return httpmock.NewStringResponse(401, "Unauthorized"), nil
|
||||
},
|
||||
)
|
||||
var response json.RawMessage
|
||||
testClient, _ := initializePowerDNSTestClient()
|
||||
testClient.sendQuery(context.Background(), generateTestRequestURI()+"/zone", "GET", nil, &response)
|
||||
isEqual, _ := areEqualJSON([]byte(strconv.Quote("Not found")), response)
|
||||
if !isEqual {
|
||||
t.Error("Test Pdns bad get - Failed")
|
||||
}
|
||||
}
|
||||
|
||||
func testPdnsBadPatch(t *testing.T) {
|
||||
httpmock.Activate()
|
||||
defer httpmock.DeactivateAndReset()
|
||||
httpmock.RegisterResponder("PATCH", generateTestAPIVhostURL()+"/zones/example.com.",
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("X-Api-Key") == testAPIKey {
|
||||
return httpmock.NewJsonResponse(422, badPatchBody)
|
||||
}
|
||||
return httpmock.NewStringResponse(401, "Unauthorized"), nil
|
||||
},
|
||||
)
|
||||
var response json.RawMessage
|
||||
testClient, _ := initializePowerDNSTestClient()
|
||||
testClient.sendQuery(context.Background(), generateTestRequestURI()+"/zones/example.com.", "PATCH", nil, &response)
|
||||
isEqual, _ := areEqualJSON([]byte(strconv.Quote(badPatchBody)), response)
|
||||
if !isEqual {
|
||||
t.Error("Test Pdns bad patch - Failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPdnsPatchRecord(t *testing.T) {
|
||||
httpmock.Activate()
|
||||
defer httpmock.DeactivateAndReset()
|
||||
httpmock.RegisterResponder("PATCH", generateTestAPIVhostURL()+"/zones/example.com.",
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("X-Api-Key") == testAPIKey {
|
||||
return httpmock.NewJsonResponse(204, nil)
|
||||
}
|
||||
return httpmock.NewStringResponse(401, "Unauthorized"), nil
|
||||
},
|
||||
)
|
||||
var response json.RawMessage
|
||||
testClient, _ := initializePowerDNSTestClient()
|
||||
testClient.sendQuery(context.Background(), generateTestRequestURI()+"/zones/example.com.", "PATCH", nil, &response)
|
||||
t.Logf("Response is: %s", response)
|
||||
if !bytes.Equal(response, []byte("null")) {
|
||||
t.Error("Test Pdns patch record - Failed")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pyke369/golang-support/rcache"
|
||||
)
|
||||
|
||||
type (
|
||||
// DNSZones is a list of DNSZone
|
||||
DNSZones []*DNSZone
|
||||
|
||||
// DNSZone is the json structure of a DNS zone in the PDNS API
|
||||
DNSZone struct {
|
||||
DNSQuery
|
||||
Name string `json:"name"`
|
||||
Kind string `json:"kind"`
|
||||
NameServers []string `json:"nameservers"`
|
||||
SOAEditAPI string `json:"soa_edit_api"`
|
||||
}
|
||||
)
|
||||
|
||||
type listCache struct {
|
||||
result []*DNSZone
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
// List all zones in the DNSZones struct as string
|
||||
func (z DNSZones) List(sep string) string {
|
||||
ret := []string{}
|
||||
for _, l := range z {
|
||||
ret = append(ret, trimPoint(l.Name))
|
||||
}
|
||||
domainSort(ret)
|
||||
return strings.Join(ret, sep)
|
||||
}
|
||||
|
||||
// NewZone returns a new DNSZone struct, use for domain creation
|
||||
func (p *PowerDNS) NewZone(name, zoneType, soa, user, comment string, ttl int, nameServers []string, autoInc bool) (*DNSZone, error) {
|
||||
z := &DNSZone{
|
||||
Name: name,
|
||||
Kind: zoneType,
|
||||
NameServers: nameServers,
|
||||
}
|
||||
if autoInc {
|
||||
z.SOAEditAPI = "INCEPTION-INCREMENT"
|
||||
}
|
||||
// create SOA record
|
||||
soaRec, err := p.GetRecord(name, "SOA", true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
soaRec.ChangeValue(name, soa, "SOA", false, false)
|
||||
soaRec.ChangeDomain(name)
|
||||
soaRec.AddCommentAndTTL(user, comment, ttl)
|
||||
|
||||
// merge it into z
|
||||
z.RRSets = soaRec.RRSets
|
||||
return z, nil
|
||||
}
|
||||
|
||||
func (z *DNSZone) String() string {
|
||||
return string(printJSON(z))
|
||||
}
|
||||
|
||||
// AddEntries put all RRSets in q into z
|
||||
func (z *DNSZone) AddEntries(q *DNSQuery) bool {
|
||||
if q.Domain != z.Name {
|
||||
return false
|
||||
}
|
||||
z.RRSets = append(z.RRSets, q.RRSets...)
|
||||
return true
|
||||
}
|
||||
|
||||
// TransformIntoDNSQuery converts a DNSZone into a DNSQuery
|
||||
func (z *DNSZone) TransformIntoDNSQuery() *DNSQuery {
|
||||
return &DNSQuery{
|
||||
Domain: z.Name,
|
||||
RRSets: z.RRSets,
|
||||
}
|
||||
}
|
||||
|
||||
// ListZones list every zone matchin the re regexp if not empty
|
||||
func (p *PowerDNS) ListZones(re string) (DNSZones, error) {
|
||||
now := time.Now()
|
||||
if cache, ok := p.listCache[re]; ok && len(cache.result) > 0 && now.Before(cache.expire) {
|
||||
p.lock()
|
||||
cache.expire = now.Add(3 * time.Minute)
|
||||
p.unlock()
|
||||
return cache.result, nil
|
||||
}
|
||||
list, err := p.listZones(re)
|
||||
if err == nil && len(list) > 0 {
|
||||
p.lock()
|
||||
p.listCache[re] = &listCache{
|
||||
result: list,
|
||||
expire: now.Add(3 * time.Minute),
|
||||
}
|
||||
p.unlock()
|
||||
}
|
||||
return list, err
|
||||
}
|
||||
|
||||
// ListZones list every zone matchin the re regexp if not empty
|
||||
func (p *PowerDNS) listZones(re string) (DNSZones, error) {
|
||||
domains := DNSZones{}
|
||||
ret := DNSZones{}
|
||||
if _, _, err := p.sendQuery(context.Background(), fmt.Sprintf("%s/%s", p.apiURL, "zones"), "GET", nil, &domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if re == "" || re == "." {
|
||||
re = ".*"
|
||||
}
|
||||
matcher := rcache.Get(fmt.Sprintf("^%s$", re))
|
||||
if matcher == nil {
|
||||
return nil, errors.New("invalid regexp")
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if !matcher.MatchString(domain.Name) {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, domain)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// GetDomain find the domain where the record r should be
|
||||
// ex : www.example.org should be in the domain example.org, or org, or doesn't
|
||||
// exist and we return an error
|
||||
func (p *PowerDNS) GetDomain(r string) (string, error) {
|
||||
domains, err := p.ListZones("")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Remove the final point of the record if necessary
|
||||
r = trimPoint(r)
|
||||
best := ""
|
||||
|
||||
for _, domain := range domains {
|
||||
possible := trimPoint(domain.Name)
|
||||
if possible == r {
|
||||
return r, nil
|
||||
}
|
||||
if r == possible || strings.HasSuffix(r, "."+possible) && len(possible) > len(best) {
|
||||
best = possible
|
||||
}
|
||||
}
|
||||
if len(best) == 0 {
|
||||
return "", errors.New("Unknown domain")
|
||||
}
|
||||
// check there is no delegation down the road
|
||||
subs := strings.Split(strings.TrimSuffix(r, "."+best), ".")
|
||||
current := best
|
||||
rp := addPoint(r)
|
||||
for i := len(subs) - 1; i > 0; i-- {
|
||||
current = subs[i] + "." + current
|
||||
search, err := p.Search(fmt.Sprintf(current))
|
||||
if err != nil {
|
||||
return best, nil
|
||||
}
|
||||
for _, entry := range search {
|
||||
if entry.IsRecord() && // ignore non record
|
||||
entry.Type == "NS" && // need a NS
|
||||
strings.HasSuffix(rp, "."+entry.Name) {
|
||||
return best, fmt.Errorf("%s is delegated to another server", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return best, nil
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
const (
|
||||
testBaseURL string = "http://localhost:8081"
|
||||
testVhost string = "localhost"
|
||||
testAPIKey string = "123Password"
|
||||
)
|
||||
|
||||
func generateTestAPIURL() string {
|
||||
return fmt.Sprintf("%s/api/v1", testBaseURL)
|
||||
}
|
||||
|
||||
func generateTestAPIVhostURL() string {
|
||||
return fmt.Sprintf("%s/servers/%s", generateTestAPIURL(), testVhost)
|
||||
}
|
||||
|
||||
func generateTestRequestURI() string {
|
||||
return fmt.Sprintf("api/v1/servers/%s", testVhost)
|
||||
}
|
||||
|
||||
func initializePowerDNSTestClient() (*PowerDNS, error) {
|
||||
pdns, err := NewClient(testBaseURL, testAPIKey, 3, 7200)
|
||||
return pdns, err
|
||||
}
|
||||
|
||||
func generateTestZone() string {
|
||||
domain := fmt.Sprintf("test-%d.com", rand.Int())
|
||||
return domain
|
||||
}
|
||||
|
||||
func areEqualJSON(s1, s2 []byte) (bool, error) {
|
||||
var o1 interface{}
|
||||
var o2 interface{}
|
||||
|
||||
var err error
|
||||
err = json.Unmarshal([]byte(s1), &o1)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Error mashalling string 1 :: %s", err.Error())
|
||||
}
|
||||
err = json.Unmarshal(s2, &o2)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Error mashalling string 2 :: %s", err.Error())
|
||||
}
|
||||
return reflect.DeepEqual(o1, o2), nil
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pyke369/golang-support/rcache"
|
||||
)
|
||||
|
||||
func stringInSlice(a string, list []string) bool {
|
||||
for _, b := range list {
|
||||
if b == a {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// inArray checks if there is a member of "search" in "list". sort list to be more
|
||||
// efficiant. adapt with sorting "search" too
|
||||
func inArray(search, list []string) bool {
|
||||
sort.Strings(list)
|
||||
for _, g := range search {
|
||||
i := sort.Search(len(list), func(i int) bool { return list[i] >= g })
|
||||
if i < len(list) && list[i] == g {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// printJSON pretty print d
|
||||
func printJSON(d interface{}) []byte {
|
||||
ret, _ := json.MarshalIndent(d, " ", " ")
|
||||
return ret
|
||||
}
|
||||
|
||||
// reverse the part order on every member of the array
|
||||
func reverse(s string) string {
|
||||
part := strings.Split(s, ".")
|
||||
for i, j := 0, len(part)-1; i < j; i, j = i+1, j-1 {
|
||||
part[i], part[j] = part[j], part[i]
|
||||
}
|
||||
return strings.Join(part, ".")
|
||||
}
|
||||
|
||||
// domainSort sort string by first reversing them
|
||||
func domainSort(s []string) {
|
||||
for i := range s {
|
||||
s[i] = reverse(s[i])
|
||||
}
|
||||
sort.Strings(s)
|
||||
for i := range s {
|
||||
s[i] = reverse(s[i])
|
||||
}
|
||||
}
|
||||
|
||||
func ptrToIP(s string) string {
|
||||
s = reverse(s)
|
||||
count := 0
|
||||
ip := ""
|
||||
version := 4
|
||||
for _, elt := range strings.Split(s, ".") {
|
||||
switch elt {
|
||||
case "":
|
||||
case "ip6":
|
||||
version = 6
|
||||
case "in-addr":
|
||||
case "arpa":
|
||||
default:
|
||||
count++
|
||||
ip += elt
|
||||
if version == 4 && count != 4 {
|
||||
ip += "."
|
||||
}
|
||||
if version == 6 && count%4 == 0 && count != 32 {
|
||||
ip += ":"
|
||||
}
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func ipToInt(ip string) *big.Int {
|
||||
return big.NewInt(0).SetBytes([]byte(net.ParseIP(ip)))
|
||||
}
|
||||
|
||||
// iPtoReverse calculate the reverse name associated with an IPv4 or v6
|
||||
func iPtoReverse(ip net.IP) (arpa string) {
|
||||
const hexDigit = "0123456789abcdef"
|
||||
// code copied and adapted from the net library
|
||||
// ip can be 4 or 16 bytes long
|
||||
if ip.To4() != nil {
|
||||
if len(ip) == 16 {
|
||||
return uitoa(uint(ip[15])) + "." + uitoa(uint(ip[14])) + "." + uitoa(uint(ip[13])) + "." + uitoa(uint(ip[12])) + ".in-addr.arpa."
|
||||
}
|
||||
return uitoa(uint(ip[3])) + "." + uitoa(uint(ip[2])) + "." + uitoa(uint(ip[1])) + "." + uitoa(uint(ip[0])) + ".in-addr.arpa."
|
||||
}
|
||||
// Must be IPv6
|
||||
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
|
||||
|
||||
// Add it, in reverse, to the buffer
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
v := ip[i]
|
||||
buf = append(buf, hexDigit[v&0xF])
|
||||
buf = append(buf, '.')
|
||||
buf = append(buf, hexDigit[v>>4])
|
||||
buf = append(buf, '.')
|
||||
}
|
||||
// Append "ip6.arpa." and return (buf already has the final .)
|
||||
buf = append(buf, "ip6.arpa."...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// Convert unsigned integer to decimal string.
|
||||
// code copied from the net library
|
||||
func uitoa(val uint) string {
|
||||
if val == 0 { // avoid string allocation
|
||||
return "0"
|
||||
}
|
||||
var buf [20]byte // big enough for 64bit value base 10
|
||||
i := len(buf) - 1
|
||||
for val >= 10 {
|
||||
q := val / 10
|
||||
buf[i] = byte('0' + val - q*10)
|
||||
i--
|
||||
val = q
|
||||
}
|
||||
// val < 10
|
||||
buf[i] = byte('0' + val)
|
||||
return string(buf[i:])
|
||||
}
|
||||
|
||||
func trimPoint(s string) string {
|
||||
return strings.TrimRight(s, ".")
|
||||
}
|
||||
|
||||
func addPoint(s string) string {
|
||||
s = trimPoint(s)
|
||||
if !strings.HasSuffix(s, ".") && !strings.HasSuffix(s, "*") {
|
||||
s = s + "."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// add double quotes at the begining and the end of the string.
|
||||
func addQuotes(s string) string {
|
||||
const q = "\""
|
||||
s = strings.Trim(s, q)
|
||||
s = strings.Replace(s, q, "\\\"", -1)
|
||||
return q + s + q
|
||||
}
|
||||
|
||||
func validSRVName(s string) bool {
|
||||
return rcache.Get("^_[a-z0-9]+\\._(tcp|udp|tls)[^ ]+$").MatchString(s)
|
||||
}
|
||||
|
||||
func validSRV(s string) string {
|
||||
// _service._proto.name. TTL class SRV priority weight port target.
|
||||
match := rcache.Get("^([0-9]+) +([0-9]+) +([0-9]+) +(.+)$").FindStringSubmatch(s)
|
||||
if len(match) > 0 && validName(match[4], false) {
|
||||
return match[4]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func validName(s string, wildcard bool) bool {
|
||||
if !wildcard {
|
||||
return rcache.Get(
|
||||
"^(([a-z0-9]|[a-z0-9_][a-z0-9_\\-]*[a-z0-9])\\.)*([a-z0-9]|[a-z0-9][a-z0-9\\-]*[a-z0-9])\\.$").MatchString(s)
|
||||
}
|
||||
return rcache.Get(
|
||||
"^(\\*\\.)?(([a-z0-9]|[a-z0-9_][a-z0-9_\\-]*[a-z0-9])\\.)*([a-z0-9]|[a-z0-9][a-z0-9\\-]*[a-z0-9])\\.$").MatchString(s)
|
||||
}
|
||||
|
||||
func validMX(s string) string {
|
||||
match := rcache.Get("^([0-9]+) +(.+)$").FindStringSubmatch(s)
|
||||
if len(match) > 0 && validName(match[2], false) {
|
||||
return match[2]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func validCAA(s string) string {
|
||||
match := rcache.Get("^([0-9]+) +(issue|issuewild|iodef) +(.+)$").FindStringSubmatch(s)
|
||||
if len(match) == 0 {
|
||||
return ""
|
||||
}
|
||||
i, _ := strconv.Atoi(match[1])
|
||||
if i < 0 || i > 255 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s %s %s", match[1], match[2], addQuotes(match[3]))
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddQuotes(t *testing.T) {
|
||||
for s, wanted := range map[string]string{
|
||||
"toto": "\"toto\"",
|
||||
"\"titi": "\"titi\"",
|
||||
"\"tutu\"": "\"tutu\"",
|
||||
"te\"te": "\"te\\\"te\"",
|
||||
"\"\"ta\"ta": "\"ta\\\"ta\"",
|
||||
} {
|
||||
if res := addQuotes(s); res != wanted {
|
||||
t.Errorf("%s quoted as %s and not %s", s, res, wanted)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
sudo: false
|
||||
|
||||
language: go
|
||||
|
||||
before_script:
|
||||
- go get -u golang.org/x/lint/golint
|
||||
|
||||
go:
|
||||
- 1.10.x
|
||||
- master
|
||||
|
||||
script:
|
||||
- test -z "$(gofmt -s -l . | tee /dev/stderr)"
|
||||
- test -z "$(golint ./... | tee /dev/stderr)"
|
||||
- go vet ./...
|
||||
- go build -v ./...
|
||||
- go test -v ./...
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Microsoft
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,29 @@
|
|||
# go-ntlmssp
|
||||
Golang package that provides NTLM/Negotiate authentication over HTTP
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/Azure/go-ntlmssp?status.svg)](https://godoc.org/github.com/Azure/go-ntlmssp) [![Build Status](https://travis-ci.org/Azure/go-ntlmssp.svg?branch=dev)](https://travis-ci.org/Azure/go-ntlmssp)
|
||||
|
||||
Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx
|
||||
Implementation hints from http://davenport.sourceforge.net/ntlm.html
|
||||
|
||||
This package only implements authentication, no key exchange or encryption. It
|
||||
only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
This package implements NTLMv2.
|
||||
|
||||
# Usage
|
||||
|
||||
```
|
||||
url, user, password := "http://www.example.com/secrets", "robpike", "pw123"
|
||||
client := &http.Client{
|
||||
Transport: ntlmssp.Negotiator{
|
||||
RoundTripper:&http.Transport{},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.SetBasicAuth(user, password)
|
||||
res, _ := client.Do(req)
|
||||
```
|
||||
|
||||
-----
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,41 @@
|
|||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
@ -0,0 +1,187 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type authenicateMessage struct {
|
||||
LmChallengeResponse []byte
|
||||
NtChallengeResponse []byte
|
||||
|
||||
TargetName string
|
||||
UserName string
|
||||
|
||||
// only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH
|
||||
EncryptedRandomSessionKey []byte
|
||||
|
||||
NegotiateFlags negotiateFlags
|
||||
|
||||
MIC []byte
|
||||
}
|
||||
|
||||
type authenticateMessageFields struct {
|
||||
messageHeader
|
||||
LmChallengeResponse varField
|
||||
NtChallengeResponse varField
|
||||
TargetName varField
|
||||
UserName varField
|
||||
Workstation varField
|
||||
_ [8]byte
|
||||
NegotiateFlags negotiateFlags
|
||||
}
|
||||
|
||||
func (m authenicateMessage) MarshalBinary() ([]byte, error) {
|
||||
if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) {
|
||||
return nil, errors.New("Only unicode is supported")
|
||||
}
|
||||
|
||||
target, user := toUnicode(m.TargetName), toUnicode(m.UserName)
|
||||
workstation := toUnicode("")
|
||||
|
||||
ptr := binary.Size(&authenticateMessageFields{})
|
||||
f := authenticateMessageFields{
|
||||
messageHeader: newMessageHeader(3),
|
||||
NegotiateFlags: m.NegotiateFlags,
|
||||
LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)),
|
||||
NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)),
|
||||
TargetName: newVarField(&ptr, len(target)),
|
||||
UserName: newVarField(&ptr, len(user)),
|
||||
Workstation: newVarField(&ptr, len(workstation)),
|
||||
}
|
||||
|
||||
f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION)
|
||||
|
||||
b := bytes.Buffer{}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.LmChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.NtChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &workstation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message
|
||||
//that was received from the server
|
||||
func ProcessChallenge(challengeMessageData []byte, user, password string, domainNeeded bool) ([]byte, error) {
|
||||
if user == "" && password == "" {
|
||||
return nil, errors.New("Anonymous authentication not supported")
|
||||
}
|
||||
|
||||
var cm challengeMessage
|
||||
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
|
||||
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
|
||||
}
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
|
||||
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
|
||||
}
|
||||
|
||||
if !domainNeeded {
|
||||
cm.TargetName = ""
|
||||
}
|
||||
|
||||
am := authenicateMessage{
|
||||
UserName: user,
|
||||
TargetName: cm.TargetName,
|
||||
NegotiateFlags: cm.NegotiateFlags,
|
||||
}
|
||||
|
||||
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
|
||||
if timestamp == nil { // no time sent, take current time
|
||||
ft := uint64(time.Now().UnixNano()) / 100
|
||||
ft += 116444736000000000 // add time between unix & windows offset
|
||||
timestamp = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(timestamp, ft)
|
||||
}
|
||||
|
||||
clientChallenge := make([]byte, 8)
|
||||
rand.Reader.Read(clientChallenge)
|
||||
|
||||
ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName)
|
||||
|
||||
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
|
||||
|
||||
if cm.TargetInfoRaw == nil {
|
||||
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge)
|
||||
}
|
||||
return am.MarshalBinary()
|
||||
}
|
||||
|
||||
func ProcessChallengeWithHash(challengeMessageData []byte, user, hash string) ([]byte, error) {
|
||||
if user == "" && hash == "" {
|
||||
return nil, errors.New("Anonymous authentication not supported")
|
||||
}
|
||||
|
||||
var cm challengeMessage
|
||||
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
|
||||
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
|
||||
}
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
|
||||
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
|
||||
}
|
||||
|
||||
am := authenicateMessage{
|
||||
UserName: user,
|
||||
TargetName: cm.TargetName,
|
||||
NegotiateFlags: cm.NegotiateFlags,
|
||||
}
|
||||
|
||||
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
|
||||
if timestamp == nil { // no time sent, take current time
|
||||
ft := uint64(time.Now().UnixNano()) / 100
|
||||
ft += 116444736000000000 // add time between unix & windows offset
|
||||
timestamp = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(timestamp, ft)
|
||||
}
|
||||
|
||||
clientChallenge := make([]byte, 8)
|
||||
rand.Reader.Read(clientChallenge)
|
||||
|
||||
hashParts := strings.Split(hash, ":")
|
||||
if len(hashParts) > 1 {
|
||||
hash = hashParts[1]
|
||||
}
|
||||
hashBytes, err := hex.DecodeString(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ntlmV2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(user)+cm.TargetName))
|
||||
|
||||
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
|
||||
|
||||
if cm.TargetInfoRaw == nil {
|
||||
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge)
|
||||
}
|
||||
return am.MarshalBinary()
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type authheader []string
|
||||
|
||||
func (h authheader) IsBasic() bool {
|
||||
for _, s := range h {
|
||||
if strings.HasPrefix(string(s), "Basic ") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h authheader) Basic() string {
|
||||
for _, s := range h {
|
||||
if strings.HasPrefix(string(s), "Basic ") {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h authheader) IsNegotiate() bool {
|
||||
for _, s := range h {
|
||||
if strings.HasPrefix(string(s), "Negotiate") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h authheader) IsNTLM() bool {
|
||||
for _, s := range h {
|
||||
if strings.HasPrefix(string(s), "NTLM") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h authheader) GetData() ([]byte, error) {
|
||||
for _, s := range h {
|
||||
if strings.HasPrefix(string(s), "NTLM") || strings.HasPrefix(string(s), "Negotiate") || strings.HasPrefix(string(s), "Basic ") {
|
||||
p := strings.Split(string(s), " ")
|
||||
if len(p) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(string(p[1]))
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (h authheader) GetBasicCreds() (username, password string, err error) {
|
||||
d, err := h.GetData()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parts := strings.SplitN(string(d), ":", 2)
|
||||
return parts[0], parts[1], nil
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package ntlmssp
|
||||
|
||||
type avID uint16
|
||||
|
||||
const (
|
||||
avIDMsvAvEOL avID = iota
|
||||
avIDMsvAvNbComputerName
|
||||
avIDMsvAvNbDomainName
|
||||
avIDMsvAvDNSComputerName
|
||||
avIDMsvAvDNSDomainName
|
||||
avIDMsvAvDNSTreeName
|
||||
avIDMsvAvFlags
|
||||
avIDMsvAvTimestamp
|
||||
avIDMsvAvSingleHost
|
||||
avIDMsvAvTargetName
|
||||
avIDMsvChannelBindings
|
||||
)
|
|
@ -0,0 +1,82 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type challengeMessageFields struct {
|
||||
messageHeader
|
||||
TargetName varField
|
||||
NegotiateFlags negotiateFlags
|
||||
ServerChallenge [8]byte
|
||||
_ [8]byte
|
||||
TargetInfo varField
|
||||
}
|
||||
|
||||
func (m challengeMessageFields) IsValid() bool {
|
||||
return m.messageHeader.IsValid() && m.MessageType == 2
|
||||
}
|
||||
|
||||
type challengeMessage struct {
|
||||
challengeMessageFields
|
||||
TargetName string
|
||||
TargetInfo map[avID][]byte
|
||||
TargetInfoRaw []byte
|
||||
}
|
||||
|
||||
func (m *challengeMessage) UnmarshalBinary(data []byte) error {
|
||||
r := bytes.NewReader(data)
|
||||
err := binary.Read(r, binary.LittleEndian, &m.challengeMessageFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !m.challengeMessageFields.IsValid() {
|
||||
return fmt.Errorf("Message is not a valid challenge message: %+v", m.challengeMessageFields.messageHeader)
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetName.Len > 0 {
|
||||
m.TargetName, err = m.challengeMessageFields.TargetName.ReadStringFrom(data, m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetInfo.Len > 0 {
|
||||
d, err := m.challengeMessageFields.TargetInfo.ReadFrom(data)
|
||||
m.TargetInfoRaw = d
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.TargetInfo = make(map[avID][]byte)
|
||||
r := bytes.NewReader(d)
|
||||
for {
|
||||
var id avID
|
||||
var l uint16
|
||||
err = binary.Read(r, binary.LittleEndian, &id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if id == avIDMsvAvEOL {
|
||||
break
|
||||
}
|
||||
|
||||
err = binary.Read(r, binary.LittleEndian, &l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := make([]byte, l)
|
||||
n, err := r.Read(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != int(l) {
|
||||
return fmt.Errorf("Expected to read %d bytes, got only %d", l, n)
|
||||
}
|
||||
m.TargetInfo[id] = value
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
var signature = [8]byte{'N', 'T', 'L', 'M', 'S', 'S', 'P', 0}
|
||||
|
||||
type messageHeader struct {
|
||||
Signature [8]byte
|
||||
MessageType uint32
|
||||
}
|
||||
|
||||
func (h messageHeader) IsValid() bool {
|
||||
return bytes.Equal(h.Signature[:], signature[:]) &&
|
||||
h.MessageType > 0 && h.MessageType < 4
|
||||
}
|
||||
|
||||
func newMessageHeader(messageType uint32) messageHeader {
|
||||
return messageHeader{signature, messageType}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package ntlmssp
|
||||
|
||||
type negotiateFlags uint32
|
||||
|
||||
const (
|
||||
/*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0
|
||||
/*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1
|
||||
/*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2
|
||||
|
||||
/*D*/
|
||||
negotiateFlagNTLMSSPNEGOTIATESIGN = 1 << 4
|
||||
/*E*/ negotiateFlagNTLMSSPNEGOTIATESEAL = 1 << 5
|
||||
/*F*/ negotiateFlagNTLMSSPNEGOTIATEDATAGRAM = 1 << 6
|
||||
/*G*/ negotiateFlagNTLMSSPNEGOTIATELMKEY = 1 << 7
|
||||
|
||||
/*H*/
|
||||
negotiateFlagNTLMSSPNEGOTIATENTLM = 1 << 9
|
||||
|
||||
/*J*/
|
||||
negotiateFlagANONYMOUS = 1 << 11
|
||||
/*K*/ negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED = 1 << 12
|
||||
/*L*/ negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED = 1 << 13
|
||||
|
||||
/*M*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN = 1 << 15
|
||||
/*N*/ negotiateFlagNTLMSSPTARGETTYPEDOMAIN = 1 << 16
|
||||
/*O*/ negotiateFlagNTLMSSPTARGETTYPESERVER = 1 << 17
|
||||
|
||||
/*P*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY = 1 << 19
|
||||
/*Q*/ negotiateFlagNTLMSSPNEGOTIATEIDENTIFY = 1 << 20
|
||||
|
||||
/*R*/
|
||||
negotiateFlagNTLMSSPREQUESTNONNTSESSIONKEY = 1 << 22
|
||||
/*S*/ negotiateFlagNTLMSSPNEGOTIATETARGETINFO = 1 << 23
|
||||
|
||||
/*T*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEVERSION = 1 << 25
|
||||
|
||||
/*U*/
|
||||
negotiateFlagNTLMSSPNEGOTIATE128 = 1 << 29
|
||||
/*V*/ negotiateFlagNTLMSSPNEGOTIATEKEYEXCH = 1 << 30
|
||||
/*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31
|
||||
)
|
||||
|
||||
func (field negotiateFlags) Has(flags negotiateFlags) bool {
|
||||
return field&flags == flags
|
||||
}
|
||||
|
||||
func (field *negotiateFlags) Unset(flags negotiateFlags) {
|
||||
*field = *field ^ (*field & flags)
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const expMsgBodyLen = 40
|
||||
|
||||
type negotiateMessageFields struct {
|
||||
messageHeader
|
||||
NegotiateFlags negotiateFlags
|
||||
|
||||
Domain varField
|
||||
Workstation varField
|
||||
|
||||
Version
|
||||
}
|
||||
|
||||
var defaultFlags = negotiateFlagNTLMSSPNEGOTIATETARGETINFO |
|
||||
negotiateFlagNTLMSSPNEGOTIATE56 |
|
||||
negotiateFlagNTLMSSPNEGOTIATE128 |
|
||||
negotiateFlagNTLMSSPNEGOTIATEUNICODE |
|
||||
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY
|
||||
|
||||
//NewNegotiateMessage creates a new NEGOTIATE message with the
|
||||
//flags that this package supports.
|
||||
func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) {
|
||||
payloadOffset := expMsgBodyLen
|
||||
flags := defaultFlags
|
||||
|
||||
if domainName != "" {
|
||||
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED
|
||||
}
|
||||
|
||||
if workstationName != "" {
|
||||
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED
|
||||
}
|
||||
|
||||
msg := negotiateMessageFields{
|
||||
messageHeader: newMessageHeader(1),
|
||||
NegotiateFlags: flags,
|
||||
Domain: newVarField(&payloadOffset, len(domainName)),
|
||||
Workstation: newVarField(&payloadOffset, len(workstationName)),
|
||||
Version: DefaultVersion(),
|
||||
}
|
||||
|
||||
b := bytes.Buffer{}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b.Len() != expMsgBodyLen {
|
||||
return nil, errors.New("incorrect body length")
|
||||
}
|
||||
|
||||
payload := strings.ToUpper(domainName + workstationName)
|
||||
if _, err := b.WriteString(payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
|
@ -0,0 +1,151 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetDomain : parse domain name from based on slashes in the input
|
||||
// Need to check for upn as well
|
||||
func GetDomain(user string) (string, string, bool) {
|
||||
domain := ""
|
||||
domainNeeded := false
|
||||
|
||||
if strings.Contains(user, "\\") {
|
||||
ucomponents := strings.SplitN(user, "\\", 2)
|
||||
domain = ucomponents[0]
|
||||
user = ucomponents[1]
|
||||
domainNeeded = true
|
||||
} else if strings.Contains(user, "@") {
|
||||
domainNeeded = false
|
||||
} else {
|
||||
domainNeeded = true
|
||||
}
|
||||
return user, domain, domainNeeded
|
||||
}
|
||||
|
||||
//Negotiator is a http.Roundtripper decorator that automatically
|
||||
//converts basic authentication to NTLM/Negotiate authentication when appropriate.
|
||||
type Negotiator struct{ http.RoundTripper }
|
||||
|
||||
//RoundTrip sends the request to the server, handling any authentication
|
||||
//re-sends as needed.
|
||||
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
// Use default round tripper if not provided
|
||||
rt := l.RoundTripper
|
||||
if rt == nil {
|
||||
rt = http.DefaultTransport
|
||||
}
|
||||
// If it is not basic auth, just round trip the request as usual
|
||||
reqauth := authheader(req.Header.Values("Authorization"))
|
||||
if !reqauth.IsBasic() {
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
reqauthBasic := reqauth.Basic()
|
||||
// Save request body
|
||||
body := bytes.Buffer{}
|
||||
if req.Body != nil {
|
||||
_, err = body.ReadFrom(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
}
|
||||
// first try anonymous, in case the server still finds us
|
||||
// authenticated from previous traffic
|
||||
req.Header.Del("Authorization")
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
resauth := authheader(res.Header.Values("Www-Authenticate"))
|
||||
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
|
||||
// Unauthorized, Negotiate not requested, let's try with basic auth
|
||||
req.Header.Set("Authorization", string(reqauthBasic))
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
resauth = authheader(res.Header.Values("Www-Authenticate"))
|
||||
}
|
||||
|
||||
if resauth.IsNegotiate() || resauth.IsNTLM() {
|
||||
// 401 with request:Basic and response:Negotiate
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// recycle credentials
|
||||
u, p, err := reqauth.GetBasicCreds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get domain from username
|
||||
domain := ""
|
||||
u, domain, domainNeeded := GetDomain(u)
|
||||
|
||||
// send negotiate
|
||||
negotiateMessage, err := NewNegotiateMessage(domain, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// receive challenge?
|
||||
resauth = authheader(res.Header.Values("Www-Authenticate"))
|
||||
challengeMessage, err := resauth.GetData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
|
||||
// Negotiation failed, let client deal with response
|
||||
return res, nil
|
||||
}
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// send authenticate
|
||||
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
// Package ntlmssp provides NTLM/Negotiate authentication over HTTP
|
||||
//
|
||||
// Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx,
|
||||
// implementation hints from http://davenport.sourceforge.net/ntlm.html .
|
||||
// This package only implements authentication, no key exchange or encryption. It
|
||||
// only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
// This package implements NTLMv2.
|
||||
package ntlmssp
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"golang.org/x/crypto/md4"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getNtlmV2Hash(password, username, target string) []byte {
|
||||
return hmacMd5(getNtlmHash(password), toUnicode(strings.ToUpper(username)+target))
|
||||
}
|
||||
|
||||
func getNtlmHash(password string) []byte {
|
||||
hash := md4.New()
|
||||
hash.Write(toUnicode(password))
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge,
|
||||
timestamp, targetInfo []byte) []byte {
|
||||
|
||||
temp := []byte{1, 1, 0, 0, 0, 0, 0, 0}
|
||||
temp = append(temp, timestamp...)
|
||||
temp = append(temp, clientChallenge...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
temp = append(temp, targetInfo...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
|
||||
NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp)
|
||||
return append(NTProofStr, temp...)
|
||||
}
|
||||
|
||||
func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte {
|
||||
return append(hmacMd5(ntlmV2Hash, serverChallenge, clientChallenge), clientChallenge...)
|
||||
}
|
||||
|
||||
func hmacMd5(key []byte, data ...[]byte) []byte {
|
||||
mac := hmac.New(md5.New, key)
|
||||
for _, d := range data {
|
||||
mac.Write(d)
|
||||
}
|
||||
return mac.Sum(nil)
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"unicode/utf16"
|
||||
)
|
||||
|
||||
// helper func's for dealing with Windows Unicode (UTF16LE)
|
||||
|
||||
func fromUnicode(d []byte) (string, error) {
|
||||
if len(d)%2 > 0 {
|
||||
return "", errors.New("Unicode (UTF 16 LE) specified, but uneven data length")
|
||||
}
|
||||
s := make([]uint16, len(d)/2)
|
||||
err := binary.Read(bytes.NewReader(d), binary.LittleEndian, &s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(utf16.Decode(s)), nil
|
||||
}
|
||||
|
||||
func toUnicode(s string) []byte {
|
||||
uints := utf16.Encode([]rune(s))
|
||||
b := bytes.Buffer{}
|
||||
binary.Write(&b, binary.LittleEndian, &uints)
|
||||
return b.Bytes()
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type varField struct {
|
||||
Len uint16
|
||||
MaxLen uint16
|
||||
BufferOffset uint32
|
||||
}
|
||||
|
||||
func (f varField) ReadFrom(buffer []byte) ([]byte, error) {
|
||||
if len(buffer) < int(f.BufferOffset+uint32(f.Len)) {
|
||||
return nil, errors.New("Error reading data, varField extends beyond buffer")
|
||||
}
|
||||
return buffer[f.BufferOffset : f.BufferOffset+uint32(f.Len)], nil
|
||||
}
|
||||
|
||||
func (f varField) ReadStringFrom(buffer []byte, unicode bool) (string, error) {
|
||||
d, err := f.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if unicode { // UTF-16LE encoding scheme
|
||||
return fromUnicode(d)
|
||||
}
|
||||
// OEM encoding, close enough to ASCII, since no code page is specified
|
||||
return string(d), err
|
||||
}
|
||||
|
||||
func newVarField(ptr *int, fieldsize int) varField {
|
||||
f := varField{
|
||||
Len: uint16(fieldsize),
|
||||
MaxLen: uint16(fieldsize),
|
||||
BufferOffset: uint32(*ptr),
|
||||
}
|
||||
*ptr += fieldsize
|
||||
return f
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package ntlmssp
|
||||
|
||||
// Version is a struct representing https://msdn.microsoft.com/en-us/library/cc236654.aspx
|
||||
type Version struct {
|
||||
ProductMajorVersion uint8
|
||||
ProductMinorVersion uint8
|
||||
ProductBuild uint16
|
||||
_ [3]byte
|
||||
NTLMRevisionCurrent uint8
|
||||
}
|
||||
|
||||
// DefaultVersion returns a Version with "sensible" defaults (Windows 7)
|
||||
func DefaultVersion() Version {
|
||||
return Version{
|
||||
ProductMajorVersion: 6,
|
||||
ProductMinorVersion: 1,
|
||||
ProductBuild: 7601,
|
||||
NTLMRevisionCurrent: 15,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com)
|
||||
Portions copyright (c) 2015-2016 go-asn1-ber Authors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,24 @@
|
|||
[![GoDoc](https://godoc.org/gopkg.in/asn1-ber.v1?status.svg)](https://godoc.org/gopkg.in/asn1-ber.v1) [![Build Status](https://travis-ci.org/go-asn1-ber/asn1-ber.svg)](https://travis-ci.org/go-asn1-ber/asn1-ber)
|
||||
|
||||
|
||||
ASN1 BER Encoding / Decoding Library for the GO programming language.
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Required libraries:
|
||||
None
|
||||
|
||||
Working:
|
||||
Very basic encoding / decoding needed for LDAP protocol
|
||||
|
||||
Tests Implemented:
|
||||
A few
|
||||
|
||||
TODO:
|
||||
Fix all encoding / decoding to conform to ASN1 BER spec
|
||||
Implement Tests / Benchmarks
|
||||
|
||||
---
|
||||
|
||||
The Go gopher was designed by Renee French. (http://reneefrench.blogspot.com/)
|
||||
The design is licensed under the Creative Commons 3.0 Attributions license.
|
||||
Read this article for more details: http://blog.golang.org/gopher
|
|
@ -0,0 +1,625 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// MaxPacketLengthBytes specifies the maximum allowed packet size when calling ReadPacket or DecodePacket. Set to 0 for
|
||||
// no limit.
|
||||
var MaxPacketLengthBytes int64 = math.MaxInt32
|
||||
|
||||
type Packet struct {
|
||||
Identifier
|
||||
Value interface{}
|
||||
ByteValue []byte
|
||||
Data *bytes.Buffer
|
||||
Children []*Packet
|
||||
Description string
|
||||
}
|
||||
|
||||
type Identifier struct {
|
||||
ClassType Class
|
||||
TagType Type
|
||||
Tag Tag
|
||||
}
|
||||
|
||||
type Tag uint64
|
||||
|
||||
const (
|
||||
TagEOC Tag = 0x00
|
||||
TagBoolean Tag = 0x01
|
||||
TagInteger Tag = 0x02
|
||||
TagBitString Tag = 0x03
|
||||
TagOctetString Tag = 0x04
|
||||
TagNULL Tag = 0x05
|
||||
TagObjectIdentifier Tag = 0x06
|
||||
TagObjectDescriptor Tag = 0x07
|
||||
TagExternal Tag = 0x08
|
||||
TagRealFloat Tag = 0x09
|
||||
TagEnumerated Tag = 0x0a
|
||||
TagEmbeddedPDV Tag = 0x0b
|
||||
TagUTF8String Tag = 0x0c
|
||||
TagRelativeOID Tag = 0x0d
|
||||
TagSequence Tag = 0x10
|
||||
TagSet Tag = 0x11
|
||||
TagNumericString Tag = 0x12
|
||||
TagPrintableString Tag = 0x13
|
||||
TagT61String Tag = 0x14
|
||||
TagVideotexString Tag = 0x15
|
||||
TagIA5String Tag = 0x16
|
||||
TagUTCTime Tag = 0x17
|
||||
TagGeneralizedTime Tag = 0x18
|
||||
TagGraphicString Tag = 0x19
|
||||
TagVisibleString Tag = 0x1a
|
||||
TagGeneralString Tag = 0x1b
|
||||
TagUniversalString Tag = 0x1c
|
||||
TagCharacterString Tag = 0x1d
|
||||
TagBMPString Tag = 0x1e
|
||||
TagBitmask Tag = 0x1f // xxx11111b
|
||||
|
||||
// HighTag indicates the start of a high-tag byte sequence
|
||||
HighTag Tag = 0x1f // xxx11111b
|
||||
// HighTagContinueBitmask indicates the high-tag byte sequence should continue
|
||||
HighTagContinueBitmask Tag = 0x80 // 10000000b
|
||||
// HighTagValueBitmask obtains the tag value from a high-tag byte sequence byte
|
||||
HighTagValueBitmask Tag = 0x7f // 01111111b
|
||||
)
|
||||
|
||||
const (
|
||||
// LengthLongFormBitmask is the mask to apply to the length byte to see if a long-form byte sequence is used
|
||||
LengthLongFormBitmask = 0x80
|
||||
// LengthValueBitmask is the mask to apply to the length byte to get the number of bytes in the long-form byte sequence
|
||||
LengthValueBitmask = 0x7f
|
||||
|
||||
// LengthIndefinite is returned from readLength to indicate an indefinite length
|
||||
LengthIndefinite = -1
|
||||
)
|
||||
|
||||
var tagMap = map[Tag]string{
|
||||
TagEOC: "EOC (End-of-Content)",
|
||||
TagBoolean: "Boolean",
|
||||
TagInteger: "Integer",
|
||||
TagBitString: "Bit String",
|
||||
TagOctetString: "Octet String",
|
||||
TagNULL: "NULL",
|
||||
TagObjectIdentifier: "Object Identifier",
|
||||
TagObjectDescriptor: "Object Descriptor",
|
||||
TagExternal: "External",
|
||||
TagRealFloat: "Real (float)",
|
||||
TagEnumerated: "Enumerated",
|
||||
TagEmbeddedPDV: "Embedded PDV",
|
||||
TagUTF8String: "UTF8 String",
|
||||
TagRelativeOID: "Relative-OID",
|
||||
TagSequence: "Sequence and Sequence of",
|
||||
TagSet: "Set and Set OF",
|
||||
TagNumericString: "Numeric String",
|
||||
TagPrintableString: "Printable String",
|
||||
TagT61String: "T61 String",
|
||||
TagVideotexString: "Videotex String",
|
||||
TagIA5String: "IA5 String",
|
||||
TagUTCTime: "UTC Time",
|
||||
TagGeneralizedTime: "Generalized Time",
|
||||
TagGraphicString: "Graphic String",
|
||||
TagVisibleString: "Visible String",
|
||||
TagGeneralString: "General String",
|
||||
TagUniversalString: "Universal String",
|
||||
TagCharacterString: "Character String",
|
||||
TagBMPString: "BMP String",
|
||||
}
|
||||
|
||||
type Class uint8
|
||||
|
||||
const (
|
||||
ClassUniversal Class = 0 // 00xxxxxxb
|
||||
ClassApplication Class = 64 // 01xxxxxxb
|
||||
ClassContext Class = 128 // 10xxxxxxb
|
||||
ClassPrivate Class = 192 // 11xxxxxxb
|
||||
ClassBitmask Class = 192 // 11xxxxxxb
|
||||
)
|
||||
|
||||
var ClassMap = map[Class]string{
|
||||
ClassUniversal: "Universal",
|
||||
ClassApplication: "Application",
|
||||
ClassContext: "Context",
|
||||
ClassPrivate: "Private",
|
||||
}
|
||||
|
||||
type Type uint8
|
||||
|
||||
const (
|
||||
TypePrimitive Type = 0 // xx0xxxxxb
|
||||
TypeConstructed Type = 32 // xx1xxxxxb
|
||||
TypeBitmask Type = 32 // xx1xxxxxb
|
||||
)
|
||||
|
||||
var TypeMap = map[Type]string{
|
||||
TypePrimitive: "Primitive",
|
||||
TypeConstructed: "Constructed",
|
||||
}
|
||||
|
||||
var Debug = false
|
||||
|
||||
func PrintBytes(out io.Writer, buf []byte, indent string) {
|
||||
dataLines := make([]string, (len(buf)/30)+1)
|
||||
numLines := make([]string, (len(buf)/30)+1)
|
||||
|
||||
for i, b := range buf {
|
||||
dataLines[i/30] += fmt.Sprintf("%02x ", b)
|
||||
numLines[i/30] += fmt.Sprintf("%02d ", (i+1)%100)
|
||||
}
|
||||
|
||||
for i := 0; i < len(dataLines); i++ {
|
||||
_, _ = out.Write([]byte(indent + dataLines[i] + "\n"))
|
||||
_, _ = out.Write([]byte(indent + numLines[i] + "\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func WritePacket(out io.Writer, p *Packet) {
|
||||
printPacket(out, p, 0, false)
|
||||
}
|
||||
|
||||
func PrintPacket(p *Packet) {
|
||||
printPacket(os.Stdout, p, 0, false)
|
||||
}
|
||||
|
||||
// Return a string describing packet content. This is not recursive,
|
||||
// If the packet is a sequence, use `printPacket()`, or browse
|
||||
// sequence yourself.
|
||||
func DescribePacket(p *Packet) string {
|
||||
|
||||
classStr := ClassMap[p.ClassType]
|
||||
|
||||
tagTypeStr := TypeMap[p.TagType]
|
||||
|
||||
tagStr := fmt.Sprintf("0x%02X", p.Tag)
|
||||
|
||||
if p.ClassType == ClassUniversal {
|
||||
tagStr = tagMap[p.Tag]
|
||||
}
|
||||
|
||||
value := fmt.Sprint(p.Value)
|
||||
description := ""
|
||||
|
||||
if p.Description != "" {
|
||||
description = p.Description + ": "
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s(%s, %s, %s) Len=%d %q", description, classStr, tagTypeStr, tagStr, p.Data.Len(), value)
|
||||
}
|
||||
|
||||
func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
|
||||
indentStr := ""
|
||||
|
||||
for len(indentStr) != indent {
|
||||
indentStr += " "
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(out, "%s%s\n", indentStr, DescribePacket(p))
|
||||
|
||||
if printBytes {
|
||||
PrintBytes(out, p.Bytes(), indentStr)
|
||||
}
|
||||
|
||||
for _, child := range p.Children {
|
||||
printPacket(out, child, indent+1, printBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a single Packet from the reader.
|
||||
func ReadPacket(reader io.Reader) (*Packet, error) {
|
||||
p, _, err := readPacket(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func DecodeString(data []byte) string {
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func ParseInt64(bytes []byte) (ret int64, err error) {
|
||||
if len(bytes) > 8 {
|
||||
// We'll overflow an int64 in this case.
|
||||
err = fmt.Errorf("integer too large")
|
||||
return
|
||||
}
|
||||
for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
|
||||
ret <<= 8
|
||||
ret |= int64(bytes[bytesRead])
|
||||
}
|
||||
|
||||
// Shift up and down in order to sign extend the result.
|
||||
ret <<= 64 - uint8(len(bytes))*8
|
||||
ret >>= 64 - uint8(len(bytes))*8
|
||||
return
|
||||
}
|
||||
|
||||
func encodeInteger(i int64) []byte {
|
||||
n := int64Length(i)
|
||||
out := make([]byte, n)
|
||||
|
||||
var j int
|
||||
for ; n > 0; n-- {
|
||||
out[j] = byte(i >> uint((n-1)*8))
|
||||
j++
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func int64Length(i int64) (numBytes int) {
|
||||
numBytes = 1
|
||||
|
||||
for i > 127 {
|
||||
numBytes++
|
||||
i >>= 8
|
||||
}
|
||||
|
||||
for i < -128 {
|
||||
numBytes++
|
||||
i >>= 8
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// DecodePacket decodes the given bytes into a single Packet
|
||||
// If a decode error is encountered, nil is returned.
|
||||
func DecodePacket(data []byte) *Packet {
|
||||
p, _, _ := readPacket(bytes.NewBuffer(data))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// DecodePacketErr decodes the given bytes into a single Packet
|
||||
// If a decode error is encountered, nil is returned.
|
||||
func DecodePacketErr(data []byte) (*Packet, error) {
|
||||
p, _, err := readPacket(bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// readPacket reads a single Packet from the reader, returning the number of bytes read.
|
||||
func readPacket(reader io.Reader) (*Packet, int, error) {
|
||||
identifier, length, read, err := readHeader(reader)
|
||||
if err != nil {
|
||||
return nil, read, err
|
||||
}
|
||||
|
||||
p := &Packet{
|
||||
Identifier: identifier,
|
||||
}
|
||||
|
||||
p.Data = new(bytes.Buffer)
|
||||
p.Children = make([]*Packet, 0, 2)
|
||||
p.Value = nil
|
||||
|
||||
if p.TagType == TypeConstructed {
|
||||
// TODO: if universal, ensure tag type is allowed to be constructed
|
||||
|
||||
// Track how much content we've read
|
||||
contentRead := 0
|
||||
for {
|
||||
if length != LengthIndefinite {
|
||||
// End if we've read what we've been told to
|
||||
if contentRead == length {
|
||||
break
|
||||
}
|
||||
// Detect if a packet boundary didn't fall on the expected length
|
||||
if contentRead > length {
|
||||
return nil, read, fmt.Errorf("expected to read %d bytes, read %d", length, contentRead)
|
||||
}
|
||||
}
|
||||
|
||||
// Read the next packet
|
||||
child, r, err := readPacket(reader)
|
||||
if err != nil {
|
||||
return nil, read, unexpectedEOF(err)
|
||||
}
|
||||
contentRead += r
|
||||
read += r
|
||||
|
||||
// Test is this is the EOC marker for our packet
|
||||
if isEOCPacket(child) {
|
||||
if length == LengthIndefinite {
|
||||
break
|
||||
}
|
||||
return nil, read, errors.New("eoc child not allowed with definite length")
|
||||
}
|
||||
|
||||
// Append and continue
|
||||
p.AppendChild(child)
|
||||
}
|
||||
return p, read, nil
|
||||
}
|
||||
|
||||
if length == LengthIndefinite {
|
||||
return nil, read, errors.New("indefinite length used with primitive type")
|
||||
}
|
||||
|
||||
// Read definite-length content
|
||||
if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes {
|
||||
return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes)
|
||||
}
|
||||
content := make([]byte, length)
|
||||
if length > 0 {
|
||||
_, err := io.ReadFull(reader, content)
|
||||
if err != nil {
|
||||
return nil, read, unexpectedEOF(err)
|
||||
}
|
||||
read += length
|
||||
}
|
||||
|
||||
if p.ClassType == ClassUniversal {
|
||||
p.Data.Write(content)
|
||||
p.ByteValue = content
|
||||
|
||||
switch p.Tag {
|
||||
case TagEOC:
|
||||
case TagBoolean:
|
||||
val, _ := ParseInt64(content)
|
||||
|
||||
p.Value = val != 0
|
||||
case TagInteger:
|
||||
p.Value, _ = ParseInt64(content)
|
||||
case TagBitString:
|
||||
case TagOctetString:
|
||||
// the actual string encoding is not known here
|
||||
// (e.g. for LDAP content is already an UTF8-encoded
|
||||
// string). Return the data without further processing
|
||||
p.Value = DecodeString(content)
|
||||
case TagNULL:
|
||||
case TagObjectIdentifier:
|
||||
case TagObjectDescriptor:
|
||||
case TagExternal:
|
||||
case TagRealFloat:
|
||||
p.Value, err = ParseReal(content)
|
||||
case TagEnumerated:
|
||||
p.Value, _ = ParseInt64(content)
|
||||
case TagEmbeddedPDV:
|
||||
case TagUTF8String:
|
||||
val := DecodeString(content)
|
||||
if !utf8.Valid([]byte(val)) {
|
||||
err = errors.New("invalid UTF-8 string")
|
||||
} else {
|
||||
p.Value = val
|
||||
}
|
||||
case TagRelativeOID:
|
||||
case TagSequence:
|
||||
case TagSet:
|
||||
case TagNumericString:
|
||||
case TagPrintableString:
|
||||
val := DecodeString(content)
|
||||
if err = isPrintableString(val); err == nil {
|
||||
p.Value = val
|
||||
}
|
||||
case TagT61String:
|
||||
case TagVideotexString:
|
||||
case TagIA5String:
|
||||
val := DecodeString(content)
|
||||
for i, c := range val {
|
||||
if c >= 0x7F {
|
||||
err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
p.Value = val
|
||||
}
|
||||
case TagUTCTime:
|
||||
case TagGeneralizedTime:
|
||||
p.Value, err = ParseGeneralizedTime(content)
|
||||
case TagGraphicString:
|
||||
case TagVisibleString:
|
||||
case TagGeneralString:
|
||||
case TagUniversalString:
|
||||
case TagCharacterString:
|
||||
case TagBMPString:
|
||||
}
|
||||
} else {
|
||||
p.Data.Write(content)
|
||||
}
|
||||
|
||||
return p, read, err
|
||||
}
|
||||
|
||||
func isPrintableString(val string) error {
|
||||
for i, c := range val {
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z':
|
||||
case c >= 'A' && c <= 'Z':
|
||||
case c >= '0' && c <= '9':
|
||||
default:
|
||||
switch c {
|
||||
case '\'', '(', ')', '+', ',', '-', '.', '=', '/', ':', '?', ' ':
|
||||
default:
|
||||
return fmt.Errorf("invalid character in position %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Packet) Bytes() []byte {
|
||||
var out bytes.Buffer
|
||||
|
||||
out.Write(encodeIdentifier(p.Identifier))
|
||||
out.Write(encodeLength(p.Data.Len()))
|
||||
out.Write(p.Data.Bytes())
|
||||
|
||||
return out.Bytes()
|
||||
}
|
||||
|
||||
func (p *Packet) AppendChild(child *Packet) {
|
||||
p.Data.Write(child.Bytes())
|
||||
p.Children = append(p.Children, child)
|
||||
}
|
||||
|
||||
func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := new(Packet)
|
||||
|
||||
p.ClassType = classType
|
||||
p.TagType = tagType
|
||||
p.Tag = tag
|
||||
p.Data = new(bytes.Buffer)
|
||||
|
||||
p.Children = make([]*Packet, 0, 2)
|
||||
|
||||
p.Value = value
|
||||
p.Description = description
|
||||
|
||||
if value != nil {
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
if classType == ClassUniversal {
|
||||
switch tag {
|
||||
case TagOctetString:
|
||||
sv, ok := v.Interface().(string)
|
||||
|
||||
if ok {
|
||||
p.Data.Write([]byte(sv))
|
||||
}
|
||||
case TagEnumerated:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
case TagEmbeddedPDV:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
}
|
||||
} else if classType == ClassContext {
|
||||
switch tag {
|
||||
case TagEnumerated:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
case TagEmbeddedPDV:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func NewSequence(description string) *Packet {
|
||||
return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description)
|
||||
}
|
||||
|
||||
func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
|
||||
intValue := int64(0)
|
||||
|
||||
if value {
|
||||
intValue = 1
|
||||
}
|
||||
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
p.Data.Write(encodeInteger(intValue))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet.
|
||||
func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
|
||||
intValue := int64(0)
|
||||
|
||||
if value {
|
||||
intValue = 255
|
||||
}
|
||||
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
p.Data.Write(encodeInteger(intValue))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case uint:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case int64:
|
||||
p.Data.Write(encodeInteger(v))
|
||||
case uint64:
|
||||
// TODO : check range or add encodeUInt...
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case int32:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case uint32:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case int16:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case uint16:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case int8:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case uint8:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
default:
|
||||
// TODO : add support for big.Int ?
|
||||
panic(fmt.Sprintf("Invalid type %T, expected {u|}int{64|32|16|8}", v))
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
p.Data.Write([]byte(value))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
var s string
|
||||
if value.Nanosecond() != 0 {
|
||||
s = value.Format(`20060102150405.000000000Z`)
|
||||
} else {
|
||||
s = value.Format(`20060102150405Z`)
|
||||
}
|
||||
p.Value = s
|
||||
p.Data.Write([]byte(s))
|
||||
return p
|
||||
}
|
||||
|
||||
func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
p.Data.Write(encodeFloat(v))
|
||||
case float32:
|
||||
p.Data.Write(encodeFloat(float64(v)))
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v))
|
||||
}
|
||||
return p
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package ber
|
||||
|
||||
func encodeUnsignedInteger(i uint64) []byte {
|
||||
n := uint64Length(i)
|
||||
out := make([]byte, n)
|
||||
|
||||
var j int
|
||||
for ; n > 0; n-- {
|
||||
out[j] = byte(i >> uint((n-1)*8))
|
||||
j++
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func uint64Length(i uint64) (numBytes int) {
|
||||
numBytes = 1
|
||||
|
||||
for i > 255 {
|
||||
numBytes++
|
||||
i >>= 8
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrInvalidTimeFormat is returned when the generalizedTime string was not correct.
|
||||
var ErrInvalidTimeFormat = errors.New("invalid time format")
|
||||
|
||||
var zeroTime = time.Time{}
|
||||
|
||||
// ParseGeneralizedTime parses a string value and if it conforms to
|
||||
// GeneralizedTime[^0] format, will return a time.Time for that value.
|
||||
//
|
||||
// [^0]: https://www.itu.int/rec/T-REC-X.690-201508-I/en Section 11.7
|
||||
func ParseGeneralizedTime(v []byte) (time.Time, error) {
|
||||
var format string
|
||||
var fract time.Duration
|
||||
|
||||
str := []byte(DecodeString(v))
|
||||
tzIndex := bytes.IndexAny(str, "Z+-")
|
||||
if tzIndex < 0 {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
dot := bytes.IndexAny(str, ".,")
|
||||
switch dot {
|
||||
case -1:
|
||||
switch tzIndex {
|
||||
case 10:
|
||||
format = `2006010215Z`
|
||||
case 12:
|
||||
format = `200601021504Z`
|
||||
case 14:
|
||||
format = `20060102150405Z`
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
case 10, 12:
|
||||
if tzIndex < dot {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
// a "," is also allowed, but would not be parsed by time.Parse():
|
||||
str[dot] = '.'
|
||||
|
||||
// If <minute> is omitted, then <fraction> represents a fraction of an
|
||||
// hour; otherwise, if <second> and <leap-second> are omitted, then
|
||||
// <fraction> represents a fraction of a minute; otherwise, <fraction>
|
||||
// represents a fraction of a second.
|
||||
|
||||
// parse as float from dot to timezone
|
||||
f, err := strconv.ParseFloat(string(str[dot:tzIndex]), 64)
|
||||
if err != nil {
|
||||
return zeroTime, fmt.Errorf("failed to parse float: %s", err)
|
||||
}
|
||||
// ...and strip that part
|
||||
str = append(str[:dot], str[tzIndex:]...)
|
||||
tzIndex = dot
|
||||
|
||||
if dot == 10 {
|
||||
fract = time.Duration(int64(f * float64(time.Hour)))
|
||||
format = `2006010215Z`
|
||||
} else {
|
||||
fract = time.Duration(int64(f * float64(time.Minute)))
|
||||
format = `200601021504Z`
|
||||
}
|
||||
|
||||
case 14:
|
||||
if tzIndex < dot {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
str[dot] = '.'
|
||||
// no need for fractional seconds, time.Parse() handles that
|
||||
format = `20060102150405Z`
|
||||
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
l := len(str)
|
||||
switch l - tzIndex {
|
||||
case 1:
|
||||
if str[l-1] != 'Z' {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
case 3:
|
||||
format += `0700`
|
||||
str = append(str, []byte("00")...)
|
||||
case 5:
|
||||
format += `0700`
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
t, err := time.Parse(format, string(str))
|
||||
if err != nil {
|
||||
return zeroTime, fmt.Errorf("%s: %s", ErrInvalidTimeFormat, err)
|
||||
}
|
||||
return t.Add(fract), nil
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func readHeader(reader io.Reader) (identifier Identifier, length int, read int, err error) {
|
||||
var (
|
||||
c, l int
|
||||
i Identifier
|
||||
)
|
||||
|
||||
if i, c, err = readIdentifier(reader); err != nil {
|
||||
return Identifier{}, 0, read, err
|
||||
}
|
||||
identifier = i
|
||||
read += c
|
||||
|
||||
if l, c, err = readLength(reader); err != nil {
|
||||
return Identifier{}, 0, read, err
|
||||
}
|
||||
length = l
|
||||
read += c
|
||||
|
||||
// Validate length type with identifier (x.600, 8.1.3.2.a)
|
||||
if length == LengthIndefinite && identifier.TagType == TypePrimitive {
|
||||
return Identifier{}, 0, read, errors.New("indefinite length used with primitive type")
|
||||
}
|
||||
|
||||
if length < LengthIndefinite {
|
||||
err = fmt.Errorf("length cannot be less than %d", LengthIndefinite)
|
||||
return
|
||||
}
|
||||
|
||||
return identifier, length, read, nil
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func readIdentifier(reader io.Reader) (Identifier, int, error) {
|
||||
identifier := Identifier{}
|
||||
read := 0
|
||||
|
||||
// identifier byte
|
||||
b, err := readByte(reader)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Printf("error reading identifier byte: %v\n", err)
|
||||
}
|
||||
return Identifier{}, read, err
|
||||
}
|
||||
read++
|
||||
|
||||
identifier.ClassType = Class(b) & ClassBitmask
|
||||
identifier.TagType = Type(b) & TypeBitmask
|
||||
|
||||
if tag := Tag(b) & TagBitmask; tag != HighTag {
|
||||
// short-form tag
|
||||
identifier.Tag = tag
|
||||
return identifier, read, nil
|
||||
}
|
||||
|
||||
// high-tag-number tag
|
||||
tagBytes := 0
|
||||
for {
|
||||
b, err := readByte(reader)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Printf("error reading high-tag-number tag byte %d: %v\n", tagBytes, err)
|
||||
}
|
||||
return Identifier{}, read, unexpectedEOF(err)
|
||||
}
|
||||
tagBytes++
|
||||
read++
|
||||
|
||||
// Lowest 7 bits get appended to the tag value (x.690, 8.1.2.4.2.b)
|
||||
identifier.Tag <<= 7
|
||||
identifier.Tag |= Tag(b) & HighTagValueBitmask
|
||||
|
||||
// First byte may not be all zeros (x.690, 8.1.2.4.2.c)
|
||||
if tagBytes == 1 && identifier.Tag == 0 {
|
||||
return Identifier{}, read, errors.New("invalid first high-tag-number tag byte")
|
||||
}
|
||||
// Overflow of int64
|
||||
// TODO: support big int tags?
|
||||
if tagBytes > 9 {
|
||||
return Identifier{}, read, errors.New("high-tag-number tag overflow")
|
||||
}
|
||||
|
||||
// Top bit of 0 means this is the last byte in the high-tag-number tag (x.690, 8.1.2.4.2.a)
|
||||
if Tag(b)&HighTagContinueBitmask == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return identifier, read, nil
|
||||
}
|
||||
|
||||
func encodeIdentifier(identifier Identifier) []byte {
|
||||
b := []byte{0x0}
|
||||
b[0] |= byte(identifier.ClassType)
|
||||
b[0] |= byte(identifier.TagType)
|
||||
|
||||
if identifier.Tag < HighTag {
|
||||
// Short-form
|
||||
b[0] |= byte(identifier.Tag)
|
||||
} else {
|
||||
// high-tag-number
|
||||
b[0] |= byte(HighTag)
|
||||
|
||||
tag := identifier.Tag
|
||||
|
||||
b = append(b, encodeHighTag(tag)...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func encodeHighTag(tag Tag) []byte {
|
||||
// set cap=4 to hopefully avoid additional allocations
|
||||
b := make([]byte, 0, 4)
|
||||
for tag != 0 {
|
||||
// t := last 7 bits of tag (HighTagValueBitmask = 0x7F)
|
||||
t := tag & HighTagValueBitmask
|
||||
|
||||
// right shift tag 7 to remove what was just pulled off
|
||||
tag >>= 7
|
||||
|
||||
// if b already has entries this entry needs a continuation bit (0x80)
|
||||
if len(b) != 0 {
|
||||
t |= HighTagContinueBitmask
|
||||
}
|
||||
|
||||
b = append(b, byte(t))
|
||||
}
|
||||
// reverse
|
||||
// since bits were pulled off 'tag' small to high the byte slice is in reverse order.
|
||||
// example: tag = 0xFF results in {0x7F, 0x01 + 0x80 (continuation bit)}
|
||||
// this needs to be reversed into 0x81 0x7F
|
||||
for i, j := 0, len(b)-1; i < len(b)/2; i++ {
|
||||
b[i], b[j-i] = b[j-i], b[i]
|
||||
}
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func readLength(reader io.Reader) (length int, read int, err error) {
|
||||
// length byte
|
||||
b, err := readByte(reader)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Printf("error reading length byte: %v\n", err)
|
||||
}
|
||||
return 0, 0, unexpectedEOF(err)
|
||||
}
|
||||
read++
|
||||
|
||||
switch {
|
||||
case b == 0xFF:
|
||||
// Invalid 0xFF (x.600, 8.1.3.5.c)
|
||||
return 0, read, errors.New("invalid length byte 0xff")
|
||||
|
||||
case b == LengthLongFormBitmask:
|
||||
// Indefinite form, we have to decode packets until we encounter an EOC packet (x.600, 8.1.3.6)
|
||||
length = LengthIndefinite
|
||||
|
||||
case b&LengthLongFormBitmask == 0:
|
||||
// Short definite form, extract the length from the bottom 7 bits (x.600, 8.1.3.4)
|
||||
length = int(b) & LengthValueBitmask
|
||||
|
||||
case b&LengthLongFormBitmask != 0:
|
||||
// Long definite form, extract the number of length bytes to follow from the bottom 7 bits (x.600, 8.1.3.5.b)
|
||||
lengthBytes := int(b) & LengthValueBitmask
|
||||
// Protect against overflow
|
||||
// TODO: support big int length?
|
||||
if lengthBytes > 8 {
|
||||
return 0, read, errors.New("long-form length overflow")
|
||||
}
|
||||
|
||||
// Accumulate into a 64-bit variable
|
||||
var length64 int64
|
||||
for i := 0; i < lengthBytes; i++ {
|
||||
b, err = readByte(reader)
|
||||
if err != nil {
|
||||
if Debug {
|
||||
fmt.Printf("error reading long-form length byte %d: %v\n", i, err)
|
||||
}
|
||||
return 0, read, unexpectedEOF(err)
|
||||
}
|
||||
read++
|
||||
|
||||
// x.600, 8.1.3.5
|
||||
length64 <<= 8
|
||||
length64 |= int64(b)
|
||||
}
|
||||
|
||||
// Cast to a platform-specific integer
|
||||
length = int(length64)
|
||||
// Ensure we didn't overflow
|
||||
if int64(length) != length64 {
|
||||
return 0, read, errors.New("long-form length overflow")
|
||||
}
|
||||
|
||||
default:
|
||||
return 0, read, errors.New("invalid length byte")
|
||||
}
|
||||
|
||||
return length, read, nil
|
||||
}
|
||||
|
||||
func encodeLength(length int) []byte {
|
||||
lengthBytes := encodeUnsignedInteger(uint64(length))
|
||||
if length > 127 || len(lengthBytes) > 1 {
|
||||
longFormBytes := []byte{LengthLongFormBitmask | byte(len(lengthBytes))}
|
||||
longFormBytes = append(longFormBytes, lengthBytes...)
|
||||
lengthBytes = longFormBytes
|
||||
}
|
||||
return lengthBytes
|
||||
}
|
|
@ -0,0 +1,163 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func encodeFloat(v float64) []byte {
|
||||
switch {
|
||||
case math.IsInf(v, 1):
|
||||
return []byte{0x40}
|
||||
case math.IsInf(v, -1):
|
||||
return []byte{0x41}
|
||||
case math.IsNaN(v):
|
||||
return []byte{0x42}
|
||||
case v == 0.0:
|
||||
if math.Signbit(v) {
|
||||
return []byte{0x43}
|
||||
}
|
||||
return []byte{}
|
||||
default:
|
||||
// we take the easy part ;-)
|
||||
value := []byte(strconv.FormatFloat(v, 'G', -1, 64))
|
||||
var ret []byte
|
||||
if bytes.Contains(value, []byte{'E'}) {
|
||||
ret = []byte{0x03}
|
||||
} else {
|
||||
ret = []byte{0x02}
|
||||
}
|
||||
ret = append(ret, value...)
|
||||
return ret
|
||||
}
|
||||
}
|
||||
|
||||
func ParseReal(v []byte) (val float64, err error) {
|
||||
if len(v) == 0 {
|
||||
return 0.0, nil
|
||||
}
|
||||
switch {
|
||||
case v[0]&0x80 == 0x80:
|
||||
val, err = parseBinaryFloat(v)
|
||||
case v[0]&0xC0 == 0x40:
|
||||
val, err = parseSpecialFloat(v)
|
||||
case v[0]&0xC0 == 0x0:
|
||||
val, err = parseDecimalFloat(v)
|
||||
default:
|
||||
return 0.0, fmt.Errorf("invalid info block")
|
||||
}
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if val == 0.0 && !math.Signbit(val) {
|
||||
return 0.0, errors.New("REAL value +0 must be encoded with zero-length value block")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func parseBinaryFloat(v []byte) (float64, error) {
|
||||
var info byte
|
||||
var buf []byte
|
||||
|
||||
info, v = v[0], v[1:]
|
||||
|
||||
var base int
|
||||
switch info & 0x30 {
|
||||
case 0x00:
|
||||
base = 2
|
||||
case 0x10:
|
||||
base = 8
|
||||
case 0x20:
|
||||
base = 16
|
||||
case 0x30:
|
||||
return 0.0, errors.New("bits 6 and 5 of information octet for REAL are equal to 11")
|
||||
}
|
||||
|
||||
scale := uint((info & 0x0c) >> 2)
|
||||
|
||||
var expLen int
|
||||
switch info & 0x03 {
|
||||
case 0x00:
|
||||
expLen = 1
|
||||
case 0x01:
|
||||
expLen = 2
|
||||
case 0x02:
|
||||
expLen = 3
|
||||
case 0x03:
|
||||
if len(v) < 2 {
|
||||
return 0.0, errors.New("invalid data")
|
||||
}
|
||||
expLen = int(v[0])
|
||||
if expLen > 8 {
|
||||
return 0.0, errors.New("too big value of exponent")
|
||||
}
|
||||
v = v[1:]
|
||||
}
|
||||
if expLen > len(v) {
|
||||
return 0.0, errors.New("too big value of exponent")
|
||||
}
|
||||
buf, v = v[:expLen], v[expLen:]
|
||||
exponent, err := ParseInt64(buf)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if len(v) > 8 {
|
||||
return 0.0, errors.New("too big value of mantissa")
|
||||
}
|
||||
|
||||
mant, err := ParseInt64(v)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
mantissa := mant << scale
|
||||
|
||||
if info&0x40 == 0x40 {
|
||||
mantissa = -mantissa
|
||||
}
|
||||
|
||||
return float64(mantissa) * math.Pow(float64(base), float64(exponent)), nil
|
||||
}
|
||||
|
||||
func parseDecimalFloat(v []byte) (val float64, err error) {
|
||||
switch v[0] & 0x3F {
|
||||
case 0x01: // NR form 1
|
||||
var iVal int64
|
||||
iVal, err = strconv.ParseInt(strings.TrimLeft(string(v[1:]), " "), 10, 64)
|
||||
val = float64(iVal)
|
||||
case 0x02, 0x03: // NR form 2, 3
|
||||
val, err = strconv.ParseFloat(strings.Replace(strings.TrimLeft(string(v[1:]), " "), ",", ".", -1), 64)
|
||||
default:
|
||||
err = errors.New("incorrect NR form")
|
||||
}
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if val == 0.0 && math.Signbit(val) {
|
||||
return 0.0, errors.New("REAL value -0 must be encoded as a special value")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func parseSpecialFloat(v []byte) (float64, error) {
|
||||
if len(v) != 1 {
|
||||
return 0.0, errors.New(`encoding of "special value" must not contain exponent and mantissa`)
|
||||
}
|
||||
switch v[0] {
|
||||
case 0x40:
|
||||
return math.Inf(1), nil
|
||||
case 0x41:
|
||||
return math.Inf(-1), nil
|
||||
case 0x42:
|
||||
return math.NaN(), nil
|
||||
case 0x43:
|
||||
return math.Copysign(0, -1), nil
|
||||
}
|
||||
return 0.0, errors.New(`encoding of "special value" not from ASN.1 standard`)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package ber
|
||||
|
||||
import "io"
|
||||
|
||||
func readByte(reader io.Reader) (byte, error) {
|
||||
bytes := make([]byte, 1)
|
||||
_, err := io.ReadFull(reader, bytes)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return bytes[0], nil
|
||||
}
|
||||
|
||||
func unexpectedEOF(err error) error {
|
||||
if err == io.EOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func isEOCPacket(p *Packet) bool {
|
||||
return p != nil &&
|
||||
p.Tag == TagEOC &&
|
||||
p.ClassType == ClassUniversal &&
|
||||
p.TagType == TypePrimitive &&
|
||||
len(p.ByteValue) == 0 &&
|
||||
len(p.Children) == 0
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com)
|
||||
Portions copyright (c) 2015-2016 go-ldap Authors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,89 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// Attribute represents an LDAP attribute
|
||||
type Attribute struct {
|
||||
// Type is the name of the LDAP attribute
|
||||
Type string
|
||||
// Vals are the LDAP attribute values
|
||||
Vals []string
|
||||
}
|
||||
|
||||
func (a *Attribute) encode() *ber.Packet {
|
||||
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
|
||||
seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type"))
|
||||
set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
|
||||
for _, value := range a.Vals {
|
||||
set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
|
||||
}
|
||||
seq.AppendChild(set)
|
||||
return seq
|
||||
}
|
||||
|
||||
// AddRequest represents an LDAP AddRequest operation
|
||||
type AddRequest struct {
|
||||
// DN identifies the entry being added
|
||||
DN string
|
||||
// Attributes list the attributes of the new entry
|
||||
Attributes []Attribute
|
||||
// Controls hold optional controls to send with the request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *AddRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
|
||||
attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
|
||||
for _, attribute := range req.Attributes {
|
||||
attributes.AppendChild(attribute.encode())
|
||||
}
|
||||
pkt.AppendChild(attributes)
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attribute adds an attribute with the given type and values
|
||||
func (req *AddRequest) Attribute(attrType string, attrVals []string) {
|
||||
req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals})
|
||||
}
|
||||
|
||||
// NewAddRequest returns an AddRequest for the given DN, with no attributes
|
||||
func NewAddRequest(dn string, controls []Control) *AddRequest {
|
||||
return &AddRequest{
|
||||
DN: dn,
|
||||
Controls: controls,
|
||||
}
|
||||
}
|
||||
|
||||
// Add performs the given AddRequest
|
||||
func (l *Conn) Add(addRequest *AddRequest) error {
|
||||
msgCtx, err := l.doRequest(addRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if packet.Children[1].Tag == ApplicationAddResponse {
|
||||
err := GetLDAPError(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,735 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
enchex "encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/go-ntlmssp"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// SimpleBindRequest represents a username/password bind operation
|
||||
type SimpleBindRequest struct {
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
// AllowEmptyPassword sets whether the client allows binding with an empty password
|
||||
// (normally used for unauthenticated bind).
|
||||
AllowEmptyPassword bool
|
||||
}
|
||||
|
||||
// SimpleBindResult contains the response from the server
|
||||
type SimpleBindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// NewSimpleBindRequest returns a bind request
|
||||
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
|
||||
return &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
Controls: controls,
|
||||
AllowEmptyPassword: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password"))
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SimpleBind performs the simple bind operation defined in the given request
|
||||
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
|
||||
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(simpleBindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SimpleBindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
|
||||
if len(packet.Children) == 3 {
|
||||
for _, child := range packet.Children[2].Children {
|
||||
decodedChild, decodeErr := DecodeControl(child)
|
||||
if decodeErr != nil {
|
||||
return nil, fmt.Errorf("failed to decode child control: %s", decodeErr)
|
||||
}
|
||||
result.Controls = append(result.Controls, decodedChild)
|
||||
}
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Bind performs a bind with the given username and password.
|
||||
//
|
||||
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
|
||||
// for that.
|
||||
func (l *Conn) Bind(username, password string) error {
|
||||
req := &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
AllowEmptyPassword: false,
|
||||
}
|
||||
_, err := l.SimpleBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnauthenticatedBind performs an unauthenticated bind.
|
||||
//
|
||||
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
|
||||
// authenticated or otherwise validated by the LDAP server.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
|
||||
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
|
||||
func (l *Conn) UnauthenticatedBind(username string) error {
|
||||
req := &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: "",
|
||||
AllowEmptyPassword: true,
|
||||
}
|
||||
_, err := l.SimpleBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// DigestMD5BindRequest represents a digest-md5 bind operation
|
||||
type DigestMD5BindRequest struct {
|
||||
Host string
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *DigestMD5BindRequest) appendTo(envelope *ber.Packet) error {
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
|
||||
request.AppendChild(auth)
|
||||
envelope.AppendChild(request)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DigestMD5BindResult contains the response from the server
|
||||
type DigestMD5BindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// MD5Bind performs a digest-md5 bind with the given host, username and password.
|
||||
func (l *Conn) MD5Bind(host, username, password string) error {
|
||||
req := &DigestMD5BindRequest{
|
||||
Host: host,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
_, err := l.DigestMD5Bind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// DigestMD5Bind performs the digest-md5 bind operation defined in the given request
|
||||
func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*DigestMD5BindResult, error) {
|
||||
if digestMD5BindRequest.Password == "" {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(digestMD5BindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if l.Debug {
|
||||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
|
||||
result := &DigestMD5BindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
var params map[string]string
|
||||
if len(packet.Children) == 2 {
|
||||
if len(packet.Children[1].Children) == 4 {
|
||||
child := packet.Children[1].Children[0]
|
||||
if child.Tag != ber.TagEnumerated {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
if child.Value.(int64) != 14 {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
child = packet.Children[1].Children[3]
|
||||
if child.Tag != ber.TagObjectDescriptor {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
if child.Data == nil {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
data, _ := ioutil.ReadAll(child.Data)
|
||||
params, err = parseParams(string(data))
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("parsing digest-challenge: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
resp := computeResponse(
|
||||
params,
|
||||
"ldap/"+strings.ToLower(digestMD5BindRequest.Host),
|
||||
digestMD5BindRequest.Username,
|
||||
digestMD5BindRequest.Password,
|
||||
)
|
||||
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, resp, "Credentials"))
|
||||
request.AppendChild(auth)
|
||||
packet.AppendChild(request)
|
||||
msgCtx, err = l.sendMessage(packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send message: %s", err)
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func parseParams(str string) (map[string]string, error) {
|
||||
m := make(map[string]string)
|
||||
var key, value string
|
||||
var state int
|
||||
for i := 0; i <= len(str); i++ {
|
||||
switch state {
|
||||
case 0: // reading key
|
||||
if i == len(str) {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
if str[i] != '=' {
|
||||
key += string(str[i])
|
||||
continue
|
||||
}
|
||||
state = 1
|
||||
case 1: // reading value
|
||||
if i == len(str) {
|
||||
m[key] = value
|
||||
break
|
||||
}
|
||||
switch str[i] {
|
||||
case ',':
|
||||
m[key] = value
|
||||
state = 0
|
||||
key = ""
|
||||
value = ""
|
||||
case '"':
|
||||
if value != "" {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
state = 2
|
||||
default:
|
||||
value += string(str[i])
|
||||
}
|
||||
case 2: // inside quotes
|
||||
if i == len(str) {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
if str[i] != '"' {
|
||||
value += string(str[i])
|
||||
} else {
|
||||
state = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func computeResponse(params map[string]string, uri, username, password string) string {
|
||||
nc := "00000001"
|
||||
qop := "auth"
|
||||
cnonce := enchex.EncodeToString(randomBytes(16))
|
||||
x := username + ":" + params["realm"] + ":" + password
|
||||
y := md5Hash([]byte(x))
|
||||
|
||||
a1 := bytes.NewBuffer(y)
|
||||
a1.WriteString(":" + params["nonce"] + ":" + cnonce)
|
||||
if len(params["authzid"]) > 0 {
|
||||
a1.WriteString(":" + params["authzid"])
|
||||
}
|
||||
a2 := bytes.NewBuffer([]byte("AUTHENTICATE"))
|
||||
a2.WriteString(":" + uri)
|
||||
ha1 := enchex.EncodeToString(md5Hash(a1.Bytes()))
|
||||
ha2 := enchex.EncodeToString(md5Hash(a2.Bytes()))
|
||||
|
||||
kd := ha1
|
||||
kd += ":" + params["nonce"]
|
||||
kd += ":" + nc
|
||||
kd += ":" + cnonce
|
||||
kd += ":" + qop
|
||||
kd += ":" + ha2
|
||||
resp := enchex.EncodeToString(md5Hash([]byte(kd)))
|
||||
return fmt.Sprintf(
|
||||
`username="%s",realm="%s",nonce="%s",cnonce="%s",nc=00000001,qop=%s,digest-uri="%s",response=%s`,
|
||||
username,
|
||||
params["realm"],
|
||||
params["nonce"],
|
||||
cnonce,
|
||||
qop,
|
||||
uri,
|
||||
resp,
|
||||
)
|
||||
}
|
||||
|
||||
func md5Hash(b []byte) []byte {
|
||||
hasher := md5.New()
|
||||
hasher.Write(b)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func randomBytes(len int) []byte {
|
||||
b := make([]byte, len)
|
||||
for i := 0; i < len; i++ {
|
||||
b[i] = byte(rand.Intn(256))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
var externalBindRequest = requestFunc(func(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech"))
|
||||
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred"))
|
||||
|
||||
pkt.AppendChild(saslAuth)
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// ExternalBind performs SASL/EXTERNAL authentication.
|
||||
//
|
||||
// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4422#appendix-A
|
||||
func (l *Conn) ExternalBind() error {
|
||||
msgCtx, err := l.doRequest(externalBindRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return GetLDAPError(packet)
|
||||
}
|
||||
|
||||
// NTLMBind performs an NTLMSSP bind leveraging https://github.com/Azure/go-ntlmssp
|
||||
|
||||
// NTLMBindRequest represents an NTLMSSP bind operation
|
||||
type NTLMBindRequest struct {
|
||||
// Domain is the AD Domain to authenticate too. If not specified, it will be grabbed from the NTLMSSP Challenge
|
||||
Domain string
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// AllowEmptyPassword sets whether the client allows binding with an empty password
|
||||
// (normally used for unauthenticated bind).
|
||||
AllowEmptyPassword bool
|
||||
// Hash is the hex NTLM hash to bind with. Password or hash must be provided
|
||||
Hash string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *NTLMBindRequest) appendTo(envelope *ber.Packet) error {
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
// generate an NTLMSSP Negotiation message for the specified domain (it can be blank)
|
||||
negMessage, err := ntlmssp.NewNegotiateMessage(req.Domain, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("err creating negmessage: %s", err)
|
||||
}
|
||||
|
||||
// append the generated NTLMSSP message as a TagEnumerated BER value
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEnumerated, negMessage, "authentication")
|
||||
request.AppendChild(auth)
|
||||
envelope.AppendChild(request)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NTLMBindResult contains the response from the server
|
||||
type NTLMBindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// NTLMBind performs an NTLMSSP Bind with the given domain, username and password
|
||||
func (l *Conn) NTLMBind(domain, username, password string) error {
|
||||
req := &NTLMBindRequest{
|
||||
Domain: domain,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
_, err := l.NTLMChallengeBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// NTLMUnauthenticatedBind performs an bind with an empty password.
|
||||
//
|
||||
// A username is required. The anonymous bind is not (yet) supported by the go-ntlmssp library (https://github.com/Azure/go-ntlmssp/blob/819c794454d067543bc61d29f61fef4b3c3df62c/authenticate_message.go#L87)
|
||||
//
|
||||
// See https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/b38c36ed-2804-4868-a9ff-8dd3182128e4 part 3.2.5.1.2
|
||||
func (l *Conn) NTLMUnauthenticatedBind(domain, username string) error {
|
||||
req := &NTLMBindRequest{
|
||||
Domain: domain,
|
||||
Username: username,
|
||||
Password: "",
|
||||
AllowEmptyPassword: true,
|
||||
}
|
||||
_, err := l.NTLMChallengeBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// NTLMBindWithHash performs an NTLM Bind with an NTLM hash instead of plaintext password (pass-the-hash)
|
||||
func (l *Conn) NTLMBindWithHash(domain, username, hash string) error {
|
||||
req := &NTLMBindRequest{
|
||||
Domain: domain,
|
||||
Username: username,
|
||||
Hash: hash,
|
||||
}
|
||||
_, err := l.NTLMChallengeBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// NTLMChallengeBind performs the NTLMSSP bind operation defined in the given request
|
||||
func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindResult, error) {
|
||||
if !ntlmBindRequest.AllowEmptyPassword && ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(ntlmBindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if l.Debug {
|
||||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
result := &NTLMBindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
var ntlmsspChallenge []byte
|
||||
|
||||
// now find the NTLM Response Message
|
||||
if len(packet.Children) == 2 {
|
||||
if len(packet.Children[1].Children) == 3 {
|
||||
child := packet.Children[1].Children[1]
|
||||
ntlmsspChallenge = child.ByteValue
|
||||
// Check to make sure we got the right message. It will always start with NTLMSSP
|
||||
if len(ntlmsspChallenge) < 7 || !bytes.Equal(ntlmsspChallenge[:7], []byte("NTLMSSP")) {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
l.Debug.Printf("%d: found ntlmssp challenge", msgCtx.id)
|
||||
}
|
||||
}
|
||||
if ntlmsspChallenge != nil {
|
||||
var err error
|
||||
var responseMessage []byte
|
||||
// generate a response message to the challenge with the given Username/Password if password is provided
|
||||
if ntlmBindRequest.Hash != "" {
|
||||
responseMessage, err = ntlmssp.ProcessChallengeWithHash(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Hash)
|
||||
} else if ntlmBindRequest.Password != "" || ntlmBindRequest.AllowEmptyPassword {
|
||||
_, _, domainNeeded := ntlmssp.GetDomain(ntlmBindRequest.Username)
|
||||
responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password, domainNeeded)
|
||||
} else {
|
||||
err = fmt.Errorf("need a password or hash to generate reply")
|
||||
}
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("parsing ntlm-challenge: %s", err)
|
||||
}
|
||||
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
// append the challenge response message as a TagEmbeddedPDV BER value
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEmbeddedPDV, responseMessage, "authentication")
|
||||
|
||||
request.AppendChild(auth)
|
||||
packet.AppendChild(request)
|
||||
msgCtx, err = l.sendMessage(packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send message: %s", err)
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read packet: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GSSAPIClient interface is used as the client-side implementation for the
|
||||
// GSSAPI SASL mechanism.
|
||||
// Interface inspired by GSSAPIClient from golang.org/x/crypto/ssh
|
||||
type GSSAPIClient interface {
|
||||
// InitSecContext initiates the establishment of a security context for
|
||||
// GSS-API between the client and server.
|
||||
// Initially the token parameter should be specified as nil.
|
||||
// The routine may return a outputToken which should be transferred to
|
||||
// the server, where the server will present it to AcceptSecContext.
|
||||
// If no token need be sent, InitSecContext will indicate this by setting
|
||||
// needContinue to false. To complete the context
|
||||
// establishment, one or more reply tokens may be required from the server;
|
||||
// if so, InitSecContext will return a needContinue which is true.
|
||||
// In this case, InitSecContext should be called again when the
|
||||
// reply token is received from the server, passing the reply token
|
||||
// to InitSecContext via the token parameters.
|
||||
// See RFC 4752 section 3.1.
|
||||
InitSecContext(target string, token []byte) (outputToken []byte, needContinue bool, err error)
|
||||
// NegotiateSaslAuth performs the last step of the Sasl handshake.
|
||||
// It takes a token, which, when unwrapped, describes the servers supported
|
||||
// security layers (first octet) and maximum receive buffer (remaining
|
||||
// three octets).
|
||||
// If the received token is unacceptable an error must be returned to abort
|
||||
// the handshake.
|
||||
// Outputs a signed token describing the client's selected security layer
|
||||
// and receive buffer size and optionally an authorization identity.
|
||||
// The returned token will be sent to the server and the handshake considered
|
||||
// completed successfully and the server authenticated.
|
||||
// See RFC 4752 section 3.1.
|
||||
NegotiateSaslAuth(token []byte, authzid string) ([]byte, error)
|
||||
// DeleteSecContext destroys any established secure context.
|
||||
DeleteSecContext() error
|
||||
}
|
||||
|
||||
// GSSAPIBindRequest represents a GSSAPI SASL mechanism bind request.
|
||||
// See rfc4752 and rfc4513 section 5.2.1.2.
|
||||
type GSSAPIBindRequest struct {
|
||||
// Service Principal Name user for the service ticket. Eg. "ldap/<host>"
|
||||
ServicePrincipalName string
|
||||
// (Optional) Authorization entity
|
||||
AuthZID string
|
||||
// (Optional) Controls to send with the bind request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// GSSAPIBind performs the GSSAPI SASL bind using the provided GSSAPI client.
|
||||
func (l *Conn) GSSAPIBind(client GSSAPIClient, servicePrincipal, authzid string) error {
|
||||
return l.GSSAPIBindRequest(client, &GSSAPIBindRequest{
|
||||
ServicePrincipalName: servicePrincipal,
|
||||
AuthZID: authzid,
|
||||
})
|
||||
}
|
||||
|
||||
// GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client.
|
||||
func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) error {
|
||||
//nolint:errcheck
|
||||
defer client.DeleteSecContext()
|
||||
|
||||
var err error
|
||||
var reqToken []byte
|
||||
var recvToken []byte
|
||||
needInit := true
|
||||
for {
|
||||
if needInit {
|
||||
// Establish secure context between client and server.
|
||||
reqToken, needInit, err = client.InitSecContext(req.ServicePrincipalName, recvToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Secure context is set up, perform the last step of SASL handshake.
|
||||
reqToken, err = client.NegotiateSaslAuth(recvToken, req.AuthZID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Send Bind request containing the current token and extract the
|
||||
// token sent by server.
|
||||
recvToken, err = l.saslBindTokenExchange(req.Controls, reqToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !needInit && len(recvToken) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Conn) saslBindTokenExchange(reqControls []Control, reqToken []byte) ([]byte, error) {
|
||||
// Construct LDAP Bind request with GSSAPI SASL mechanism.
|
||||
envelope := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
envelope.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "GSSAPI", "SASL Mech"))
|
||||
if len(reqToken) > 0 {
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(reqToken), "Credentials"))
|
||||
}
|
||||
request.AppendChild(auth)
|
||||
envelope.AppendChild(request)
|
||||
if len(reqControls) > 0 {
|
||||
envelope.AppendChild(encodeControls(reqControls))
|
||||
}
|
||||
|
||||
msgCtx, err := l.sendMessage(envelope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if l.Debug {
|
||||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc4511#section-4.1.1
|
||||
// packet is an envelope
|
||||
// child 0 is message id
|
||||
// child 1 is protocolOp
|
||||
if len(packet.Children) != 2 {
|
||||
return nil, fmt.Errorf("bad bind response")
|
||||
}
|
||||
|
||||
protocolOp := packet.Children[1]
|
||||
RESP:
|
||||
switch protocolOp.Description {
|
||||
case "Bind Response": // Bind Response
|
||||
// Bind Reponse is an LDAP Response (https://www.rfc-editor.org/rfc/rfc4511#section-4.1.9)
|
||||
// with an additional optional serverSaslCreds string (https://www.rfc-editor.org/rfc/rfc4511#section-4.2.2)
|
||||
// child 0 is resultCode
|
||||
resultCode := protocolOp.Children[0]
|
||||
if resultCode.Tag != ber.TagEnumerated {
|
||||
break RESP
|
||||
}
|
||||
switch resultCode.Value.(int64) {
|
||||
case 14: // Sasl bind in progress
|
||||
if len(protocolOp.Children) < 3 {
|
||||
break RESP
|
||||
}
|
||||
referral := protocolOp.Children[3]
|
||||
switch referral.Description {
|
||||
case "Referral":
|
||||
if referral.ClassType != ber.ClassContext || referral.Tag != ber.TagObjectDescriptor {
|
||||
break RESP
|
||||
}
|
||||
return ioutil.ReadAll(referral.Data)
|
||||
}
|
||||
// Optional:
|
||||
//if len(protocolOp.Children) == 4 {
|
||||
// serverSaslCreds := protocolOp.Children[4]
|
||||
//}
|
||||
case 0: // Success - Bind OK.
|
||||
// SASL layer in effect (if any) (See https://www.rfc-editor.org/rfc/rfc4513#section-5.2.1.4)
|
||||
// NOTE: SASL security layers are not supported currently.
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, GetLDAPError(packet)
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client knows how to interact with an LDAP server
|
||||
type Client interface {
|
||||
Start()
|
||||
StartTLS(*tls.Config) error
|
||||
Close() error
|
||||
GetLastError() error
|
||||
IsClosing() bool
|
||||
SetTimeout(time.Duration)
|
||||
TLSConnectionState() (tls.ConnectionState, bool)
|
||||
|
||||
Bind(username, password string) error
|
||||
UnauthenticatedBind(username string) error
|
||||
SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error)
|
||||
ExternalBind() error
|
||||
NTLMUnauthenticatedBind(domain, username string) error
|
||||
Unbind() error
|
||||
|
||||
Add(*AddRequest) error
|
||||
Del(*DelRequest) error
|
||||
Modify(*ModifyRequest) error
|
||||
ModifyDN(*ModifyDNRequest) error
|
||||
ModifyWithResult(*ModifyRequest) (*ModifyResult, error)
|
||||
|
||||
Compare(dn, attribute, value string) (bool, error)
|
||||
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
|
||||
|
||||
Search(*SearchRequest) (*SearchResult, error)
|
||||
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
|
||||
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
|
||||
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
|
||||
DirSyncAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int, flags, maxAttrCount int64, cookie []byte) Response
|
||||
Syncrepl(ctx context.Context, searchRequest *SearchRequest, bufferSize int, mode ControlSyncRequestMode, cookie []byte, reloadHint bool) Response
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// CompareRequest represents an LDAP CompareRequest operation.
|
||||
type CompareRequest struct {
|
||||
DN string
|
||||
Attribute string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (req *CompareRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
|
||||
|
||||
ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion")
|
||||
ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc"))
|
||||
ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue"))
|
||||
|
||||
pkt.AppendChild(ava)
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
|
||||
// false with any error that occurs if any.
|
||||
func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
|
||||
msgCtx, err := l.doRequest(&CompareRequest{
|
||||
DN: dn,
|
||||
Attribute: attribute,
|
||||
Value: value,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if packet.Children[1].Tag == ApplicationCompareResponse {
|
||||
err := GetLDAPError(packet)
|
||||
|
||||
switch {
|
||||
case IsErrorWithCode(err, LDAPResultCompareTrue):
|
||||
return true, nil
|
||||
case IsErrorWithCode(err, LDAPResultCompareFalse):
|
||||
return false, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return false, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)
|
||||
}
|
|
@ -0,0 +1,629 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
const (
|
||||
// MessageQuit causes the processMessages loop to exit
|
||||
MessageQuit = 0
|
||||
// MessageRequest sends a request to the server
|
||||
MessageRequest = 1
|
||||
// MessageResponse receives a response from the server
|
||||
MessageResponse = 2
|
||||
// MessageFinish indicates the client considers a particular message ID to be finished
|
||||
MessageFinish = 3
|
||||
// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
|
||||
MessageTimeout = 4
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLdapPort default ldap port for pure TCP connection
|
||||
DefaultLdapPort = "389"
|
||||
// DefaultLdapsPort default ldap port for SSL connection
|
||||
DefaultLdapsPort = "636"
|
||||
)
|
||||
|
||||
// PacketResponse contains the packet or error encountered reading a response
|
||||
type PacketResponse struct {
|
||||
// Packet is the packet read from the server
|
||||
Packet *ber.Packet
|
||||
// Error is an error encountered while reading
|
||||
Error error
|
||||
}
|
||||
|
||||
// ReadPacket returns the packet or an error
|
||||
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
|
||||
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
|
||||
}
|
||||
return pr.Packet, pr.Error
|
||||
}
|
||||
|
||||
type messageContext struct {
|
||||
id int64
|
||||
// close(done) should only be called from finishMessage()
|
||||
done chan struct{}
|
||||
// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
|
||||
responses chan *PacketResponse
|
||||
}
|
||||
|
||||
// sendResponse should only be called within the processMessages() loop which
|
||||
// is also responsible for closing the responses channel.
|
||||
func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
|
||||
timeoutCtx := context.Background()
|
||||
if timeout > 0 {
|
||||
var cancelFunc context.CancelFunc
|
||||
timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
|
||||
defer cancelFunc()
|
||||
}
|
||||
select {
|
||||
case msgCtx.responses <- packet:
|
||||
// Successfully sent packet to message handler.
|
||||
case <-msgCtx.done:
|
||||
// The request handler is done and will not receive more
|
||||
// packets.
|
||||
case <-timeoutCtx.Done():
|
||||
// The timeout was reached before the packet was sent.
|
||||
}
|
||||
}
|
||||
|
||||
type messagePacket struct {
|
||||
Op int
|
||||
MessageID int64
|
||||
Packet *ber.Packet
|
||||
Context *messageContext
|
||||
}
|
||||
|
||||
type sendMessageFlags uint
|
||||
|
||||
const (
|
||||
startTLS sendMessageFlags = 1 << iota
|
||||
)
|
||||
|
||||
// Conn represents an LDAP Connection
|
||||
type Conn struct {
|
||||
// requestTimeout is loaded atomically
|
||||
// so we need to ensure 64-bit alignment on 32-bit platforms.
|
||||
// https://github.com/go-ldap/ldap/pull/199
|
||||
requestTimeout int64
|
||||
conn net.Conn
|
||||
isTLS bool
|
||||
closing uint32
|
||||
closeErr atomic.Value
|
||||
isStartingTLS bool
|
||||
Debug debugging
|
||||
chanConfirm chan struct{}
|
||||
messageContexts map[int64]*messageContext
|
||||
chanMessage chan *messagePacket
|
||||
chanMessageID chan int64
|
||||
wgClose sync.WaitGroup
|
||||
outstandingRequests uint
|
||||
messageMutex sync.Mutex
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
var _ Client = &Conn{}
|
||||
|
||||
// DefaultTimeout is a package-level variable that sets the timeout value
|
||||
// used for the Dial and DialTLS methods.
|
||||
//
|
||||
// WARNING: since this is a package-level variable, setting this value from
|
||||
// multiple places will probably result in undesired behaviour.
|
||||
var DefaultTimeout = 60 * time.Second
|
||||
|
||||
// DialOpt configures DialContext.
|
||||
type DialOpt func(*DialContext)
|
||||
|
||||
// DialWithDialer updates net.Dialer in DialContext.
|
||||
func DialWithDialer(d *net.Dialer) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.dialer = d
|
||||
}
|
||||
}
|
||||
|
||||
// DialWithTLSConfig updates tls.Config in DialContext.
|
||||
func DialWithTLSConfig(tc *tls.Config) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.tlsConfig = tc
|
||||
}
|
||||
}
|
||||
|
||||
// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
|
||||
// specify a net.Dialer to for example define a timeout or a custom resolver.
|
||||
// @deprecated Use DialWithDialer and DialWithTLSConfig instead
|
||||
func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.tlsConfig = tlsConfig
|
||||
dc.dialer = dialer
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext contains necessary parameters to dial the given ldap URL.
|
||||
type DialContext struct {
|
||||
dialer *net.Dialer
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
|
||||
if u.Scheme == "ldapi" {
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = "/var/run/slapd/ldapi"
|
||||
}
|
||||
return dc.dialer.Dial("unix", u.Path)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
// we assume that error is due to missing port
|
||||
host = u.Host
|
||||
port = ""
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "cldap":
|
||||
if port == "" {
|
||||
port = DefaultLdapPort
|
||||
}
|
||||
return dc.dialer.Dial("udp", net.JoinHostPort(host, port))
|
||||
case "ldap":
|
||||
if port == "" {
|
||||
port = DefaultLdapPort
|
||||
}
|
||||
return dc.dialer.Dial("tcp", net.JoinHostPort(host, port))
|
||||
case "ldaps":
|
||||
if port == "" {
|
||||
port = DefaultLdapsPort
|
||||
}
|
||||
return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
|
||||
}
|
||||
|
||||
// Dial connects to the given address on the given network using net.Dial
|
||||
// and then returns a new Conn for the connection.
|
||||
// @deprecated Use DialURL instead.
|
||||
func Dial(network, addr string) (*Conn, error) {
|
||||
c, err := net.DialTimeout(network, addr, DefaultTimeout)
|
||||
if err != nil {
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
conn := NewConn(c, false)
|
||||
conn.Start()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialTLS connects to the given address on the given network using tls.Dial
|
||||
// and then returns a new Conn for the connection.
|
||||
// @deprecated Use DialURL instead.
|
||||
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
|
||||
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
|
||||
if err != nil {
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
conn := NewConn(c, true)
|
||||
conn.Start()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialURL connects to the given ldap URL.
|
||||
// The following schemas are supported: ldap://, ldaps://, ldapi://,
|
||||
// and cldap:// (RFC1798, deprecated but used by Active Directory).
|
||||
// On success a new Conn for the connection is returned.
|
||||
func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
|
||||
var dc DialContext
|
||||
for _, opt := range opts {
|
||||
opt(&dc)
|
||||
}
|
||||
if dc.dialer == nil {
|
||||
dc.dialer = &net.Dialer{Timeout: DefaultTimeout}
|
||||
}
|
||||
|
||||
c, err := dc.dial(u)
|
||||
if err != nil {
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
|
||||
conn := NewConn(c, u.Scheme == "ldaps")
|
||||
conn.Start()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewConn returns a new Conn using conn for network I/O.
|
||||
func NewConn(conn net.Conn, isTLS bool) *Conn {
|
||||
l := &Conn{
|
||||
conn: conn,
|
||||
chanConfirm: make(chan struct{}),
|
||||
chanMessageID: make(chan int64),
|
||||
chanMessage: make(chan *messagePacket, 10),
|
||||
messageContexts: map[int64]*messageContext{},
|
||||
requestTimeout: 0,
|
||||
isTLS: isTLS,
|
||||
}
|
||||
l.wgClose.Add(1)
|
||||
return l
|
||||
}
|
||||
|
||||
// Start initializes goroutines to read responses and process messages
|
||||
func (l *Conn) Start() {
|
||||
go l.reader()
|
||||
go l.processMessages()
|
||||
}
|
||||
|
||||
// IsClosing returns whether or not we're currently closing.
|
||||
func (l *Conn) IsClosing() bool {
|
||||
return atomic.LoadUint32(&l.closing) == 1
|
||||
}
|
||||
|
||||
// setClosing sets the closing value to true
|
||||
func (l *Conn) setClosing() bool {
|
||||
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (l *Conn) Close() (err error) {
|
||||
l.messageMutex.Lock()
|
||||
defer l.messageMutex.Unlock()
|
||||
|
||||
if l.setClosing() {
|
||||
l.Debug.Printf("Sending quit message and waiting for confirmation")
|
||||
l.chanMessage <- &messagePacket{Op: MessageQuit}
|
||||
|
||||
timeoutCtx := context.Background()
|
||||
if l.getTimeout() > 0 {
|
||||
var cancelFunc context.CancelFunc
|
||||
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
|
||||
defer cancelFunc()
|
||||
}
|
||||
select {
|
||||
case <-l.chanConfirm:
|
||||
// Confirmation was received.
|
||||
case <-timeoutCtx.Done():
|
||||
// The timeout was reached before confirmation was received.
|
||||
}
|
||||
|
||||
close(l.chanMessage)
|
||||
|
||||
l.Debug.Printf("Closing network connection")
|
||||
err = l.conn.Close()
|
||||
l.wgClose.Done()
|
||||
}
|
||||
l.wgClose.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
|
||||
func (l *Conn) SetTimeout(timeout time.Duration) {
|
||||
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
|
||||
}
|
||||
|
||||
func (l *Conn) getTimeout() int64 {
|
||||
return atomic.LoadInt64(&l.requestTimeout)
|
||||
}
|
||||
|
||||
// Returns the next available messageID
|
||||
func (l *Conn) nextMessageID() int64 {
|
||||
if messageID, ok := <-l.chanMessageID; ok {
|
||||
return messageID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
|
||||
// Only the last recorded error will be returned.
|
||||
func (l *Conn) GetLastError() error {
|
||||
l.messageMutex.Lock()
|
||||
defer l.messageMutex.Unlock()
|
||||
return l.err
|
||||
}
|
||||
|
||||
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
|
||||
func (l *Conn) StartTLS(config *tls.Config) error {
|
||||
if l.isTLS {
|
||||
return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
|
||||
}
|
||||
|
||||
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
|
||||
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
|
||||
packet.AppendChild(request)
|
||||
l.Debug.PrintPacket(packet)
|
||||
|
||||
msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
l.Debug.Printf("%d: waiting for response", msgCtx.id)
|
||||
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if l.Debug {
|
||||
if err := addLDAPDescriptions(packet); err != nil {
|
||||
l.Close()
|
||||
return err
|
||||
}
|
||||
l.Debug.PrintPacket(packet)
|
||||
}
|
||||
|
||||
if err := GetLDAPError(packet); err == nil {
|
||||
conn := tls.Client(l.conn, config)
|
||||
|
||||
if connErr := conn.Handshake(); connErr != nil {
|
||||
l.Close()
|
||||
return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
|
||||
}
|
||||
|
||||
l.isTLS = true
|
||||
l.conn = conn
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
go l.reader()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLSConnectionState returns the client's TLS connection state.
|
||||
// The return values are their zero values if StartTLS did
|
||||
// not succeed.
|
||||
func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
|
||||
tc, ok := l.conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
return tc.ConnectionState(), true
|
||||
}
|
||||
|
||||
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
|
||||
return l.sendMessageWithFlags(packet, 0)
|
||||
}
|
||||
|
||||
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
|
||||
if l.IsClosing() {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
|
||||
}
|
||||
l.messageMutex.Lock()
|
||||
l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
|
||||
if l.isStartingTLS {
|
||||
l.messageMutex.Unlock()
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
|
||||
}
|
||||
if flags&startTLS != 0 {
|
||||
if l.outstandingRequests != 0 {
|
||||
l.messageMutex.Unlock()
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
|
||||
}
|
||||
l.isStartingTLS = true
|
||||
}
|
||||
l.outstandingRequests++
|
||||
|
||||
l.messageMutex.Unlock()
|
||||
|
||||
responses := make(chan *PacketResponse)
|
||||
messageID := packet.Children[0].Value.(int64)
|
||||
message := &messagePacket{
|
||||
Op: MessageRequest,
|
||||
MessageID: messageID,
|
||||
Packet: packet,
|
||||
Context: &messageContext{
|
||||
id: messageID,
|
||||
done: make(chan struct{}),
|
||||
responses: responses,
|
||||
},
|
||||
}
|
||||
if !l.sendProcessMessage(message) {
|
||||
if l.IsClosing() {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
|
||||
}
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason"))
|
||||
}
|
||||
return message.Context, nil
|
||||
}
|
||||
|
||||
func (l *Conn) finishMessage(msgCtx *messageContext) {
|
||||
close(msgCtx.done)
|
||||
|
||||
if l.IsClosing() {
|
||||
return
|
||||
}
|
||||
|
||||
l.messageMutex.Lock()
|
||||
l.outstandingRequests--
|
||||
if l.isStartingTLS {
|
||||
l.isStartingTLS = false
|
||||
}
|
||||
l.messageMutex.Unlock()
|
||||
|
||||
message := &messagePacket{
|
||||
Op: MessageFinish,
|
||||
MessageID: msgCtx.id,
|
||||
}
|
||||
l.sendProcessMessage(message)
|
||||
}
|
||||
|
||||
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
|
||||
l.messageMutex.Lock()
|
||||
defer l.messageMutex.Unlock()
|
||||
if l.IsClosing() {
|
||||
return false
|
||||
}
|
||||
l.chanMessage <- message
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *Conn) processMessages() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
l.err = fmt.Errorf("ldap: recovered panic in processMessages: %v", err)
|
||||
}
|
||||
for messageID, msgCtx := range l.messageContexts {
|
||||
// If we are closing due to an error, inform anyone who
|
||||
// is waiting about the error.
|
||||
if l.IsClosing() && l.closeErr.Load() != nil {
|
||||
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
|
||||
}
|
||||
l.Debug.Printf("Closing channel for MessageID %d", messageID)
|
||||
close(msgCtx.responses)
|
||||
delete(l.messageContexts, messageID)
|
||||
}
|
||||
close(l.chanMessageID)
|
||||
close(l.chanConfirm)
|
||||
}()
|
||||
|
||||
var messageID int64 = 1
|
||||
for {
|
||||
select {
|
||||
case l.chanMessageID <- messageID:
|
||||
messageID++
|
||||
case message := <-l.chanMessage:
|
||||
switch message.Op {
|
||||
case MessageQuit:
|
||||
l.Debug.Printf("Shutting down - quit message received")
|
||||
return
|
||||
case MessageRequest:
|
||||
// Add to message list and write to network
|
||||
l.Debug.Printf("Sending message %d", message.MessageID)
|
||||
|
||||
buf := message.Packet.Bytes()
|
||||
_, err := l.conn.Write(buf)
|
||||
if err != nil {
|
||||
l.Debug.Printf("Error Sending Message: %s", err.Error())
|
||||
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
|
||||
close(message.Context.responses)
|
||||
break
|
||||
}
|
||||
|
||||
// Only add to messageContexts if we were able to
|
||||
// successfully write the message.
|
||||
l.messageContexts[message.MessageID] = message.Context
|
||||
|
||||
// Add timeout if defined
|
||||
requestTimeout := l.getTimeout()
|
||||
if requestTimeout > 0 {
|
||||
go func() {
|
||||
timer := time.NewTimer(time.Duration(requestTimeout))
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
|
||||
}
|
||||
|
||||
timer.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
timeoutMessage := &messagePacket{
|
||||
Op: MessageTimeout,
|
||||
MessageID: message.MessageID,
|
||||
}
|
||||
l.sendProcessMessage(timeoutMessage)
|
||||
case <-message.Context.done:
|
||||
}
|
||||
}()
|
||||
}
|
||||
case MessageResponse:
|
||||
l.Debug.Printf("Receiving message %d", message.MessageID)
|
||||
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
|
||||
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
|
||||
} else {
|
||||
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
|
||||
l.Debug.PrintPacket(message.Packet)
|
||||
}
|
||||
case MessageTimeout:
|
||||
// Handle the timeout by closing the channel
|
||||
// All reads will return immediately
|
||||
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
|
||||
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
|
||||
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
|
||||
delete(l.messageContexts, message.MessageID)
|
||||
close(msgCtx.responses)
|
||||
}
|
||||
case MessageFinish:
|
||||
l.Debug.Printf("Finished message %d", message.MessageID)
|
||||
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
|
||||
delete(l.messageContexts, message.MessageID)
|
||||
close(msgCtx.responses)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Conn) reader() {
|
||||
cleanstop := false
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
l.err = fmt.Errorf("ldap: recovered panic in reader: %v", err)
|
||||
}
|
||||
if !cleanstop {
|
||||
l.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
bufConn := bufio.NewReader(l.conn)
|
||||
for {
|
||||
if cleanstop {
|
||||
l.Debug.Printf("reader clean stopping (without closing the connection)")
|
||||
return
|
||||
}
|
||||
packet, err := ber.ReadPacket(bufConn)
|
||||
if err != nil {
|
||||
// A read error is expected here if we are closing the connection...
|
||||
if !l.IsClosing() {
|
||||
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
|
||||
l.Debug.Printf("reader error: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := addLDAPDescriptions(packet); err != nil {
|
||||
l.Debug.Printf("descriptions error: %s", err)
|
||||
}
|
||||
if len(packet.Children) == 0 {
|
||||
l.Debug.Printf("Received bad ldap packet")
|
||||
continue
|
||||
}
|
||||
l.messageMutex.Lock()
|
||||
if l.isStartingTLS {
|
||||
cleanstop = true
|
||||
}
|
||||
l.messageMutex.Unlock()
|
||||
message := &messagePacket{
|
||||
Op: MessageResponse,
|
||||
MessageID: packet.Children[0].Value.(int64),
|
||||
Packet: packet,
|
||||
}
|
||||
if !l.sendProcessMessage(message) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,28 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// debugging type
|
||||
// - has a Printf method to write the debug output
|
||||
type debugging bool
|
||||
|
||||
// Enable controls debugging mode.
|
||||
func (debug *debugging) Enable(b bool) {
|
||||
*debug = debugging(b)
|
||||
}
|
||||
|
||||
// Printf writes debug output.
|
||||
func (debug debugging) Printf(format string, args ...interface{}) {
|
||||
if debug {
|
||||
logger.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// PrintPacket dumps a packet.
|
||||
func (debug debugging) PrintPacket(packet *ber.Packet) {
|
||||
if debug {
|
||||
ber.WritePacket(logger.Writer(), packet)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// DelRequest implements an LDAP deletion request
|
||||
type DelRequest struct {
|
||||
// DN is the name of the directory entry to delete
|
||||
DN string
|
||||
// Controls hold optional controls to send with the request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *DelRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request")
|
||||
pkt.Data.Write([]byte(req.DN))
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewDelRequest creates a delete request for the given DN and controls
|
||||
func NewDelRequest(DN string, Controls []Control) *DelRequest {
|
||||
return &DelRequest{
|
||||
DN: DN,
|
||||
Controls: Controls,
|
||||
}
|
||||
}
|
||||
|
||||
// Del executes the given delete request
|
||||
func (l *Conn) Del(delRequest *DelRequest) error {
|
||||
msgCtx, err := l.doRequest(delRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if packet.Children[1].Tag == ApplicationDelResponse {
|
||||
err := GetLDAPError(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,350 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
enchex "encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
|
||||
type AttributeTypeAndValue struct {
|
||||
// Type is the attribute type
|
||||
Type string
|
||||
// Value is the attribute value
|
||||
Value string
|
||||
}
|
||||
|
||||
// String returns a normalized string representation of this attribute type and
|
||||
// value pair which is the a lowercased join of the Type and Value with a "=".
|
||||
func (a *AttributeTypeAndValue) String() string {
|
||||
return strings.ToLower(a.Type) + "=" + a.encodeValue()
|
||||
}
|
||||
|
||||
func (a *AttributeTypeAndValue) encodeValue() string {
|
||||
// Normalize the value first.
|
||||
// value := strings.ToLower(a.Value)
|
||||
value := a.Value
|
||||
|
||||
encodedBuf := bytes.Buffer{}
|
||||
|
||||
escapeChar := func(c byte) {
|
||||
encodedBuf.WriteByte('\\')
|
||||
encodedBuf.WriteByte(c)
|
||||
}
|
||||
|
||||
escapeHex := func(c byte) {
|
||||
encodedBuf.WriteByte('\\')
|
||||
encodedBuf.WriteString(enchex.EncodeToString([]byte{c}))
|
||||
}
|
||||
|
||||
for i := 0; i < len(value); i++ {
|
||||
char := value[i]
|
||||
if i == 0 && char == ' ' || char == '#' {
|
||||
// Special case leading space or number sign.
|
||||
escapeChar(char)
|
||||
continue
|
||||
}
|
||||
if i == len(value)-1 && char == ' ' {
|
||||
// Special case trailing space.
|
||||
escapeChar(char)
|
||||
continue
|
||||
}
|
||||
|
||||
switch char {
|
||||
case '"', '+', ',', ';', '<', '>', '\\':
|
||||
// Each of these special characters must be escaped.
|
||||
escapeChar(char)
|
||||
continue
|
||||
}
|
||||
|
||||
if char < ' ' || char > '~' {
|
||||
// All special character escapes are handled first
|
||||
// above. All bytes less than ASCII SPACE and all bytes
|
||||
// greater than ASCII TILDE must be hex-escaped.
|
||||
escapeHex(char)
|
||||
continue
|
||||
}
|
||||
|
||||
// Any other character does not require escaping.
|
||||
encodedBuf.WriteByte(char)
|
||||
}
|
||||
|
||||
return encodedBuf.String()
|
||||
}
|
||||
|
||||
// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514
|
||||
type RelativeDN struct {
|
||||
Attributes []*AttributeTypeAndValue
|
||||
}
|
||||
|
||||
// String returns a normalized string representation of this relative DN which
|
||||
// is the a join of all attributes (sorted in increasing order) with a "+".
|
||||
func (r *RelativeDN) String() string {
|
||||
attrs := make([]string, len(r.Attributes))
|
||||
for i := range r.Attributes {
|
||||
attrs[i] = r.Attributes[i].String()
|
||||
}
|
||||
sort.Strings(attrs)
|
||||
return strings.Join(attrs, "+")
|
||||
}
|
||||
|
||||
// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514
|
||||
type DN struct {
|
||||
RDNs []*RelativeDN
|
||||
}
|
||||
|
||||
// String returns a normalized string representation of this DN which is the
|
||||
// join of all relative DNs with a ",".
|
||||
func (d *DN) String() string {
|
||||
rdns := make([]string, len(d.RDNs))
|
||||
for i := range d.RDNs {
|
||||
rdns[i] = d.RDNs[i].String()
|
||||
}
|
||||
return strings.Join(rdns, ",")
|
||||
}
|
||||
|
||||
// ParseDN returns a distinguishedName or an error.
|
||||
// The function respects https://tools.ietf.org/html/rfc4514
|
||||
func ParseDN(str string) (*DN, error) {
|
||||
dn := new(DN)
|
||||
dn.RDNs = make([]*RelativeDN, 0)
|
||||
rdn := new(RelativeDN)
|
||||
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
|
||||
buffer := bytes.Buffer{}
|
||||
attribute := new(AttributeTypeAndValue)
|
||||
escaping := false
|
||||
|
||||
unescapedTrailingSpaces := 0
|
||||
stringFromBuffer := func() string {
|
||||
s := buffer.String()
|
||||
s = s[0 : len(s)-unescapedTrailingSpaces]
|
||||
buffer.Reset()
|
||||
unescapedTrailingSpaces = 0
|
||||
return s
|
||||
}
|
||||
|
||||
for i := 0; i < len(str); i++ {
|
||||
char := str[i]
|
||||
switch {
|
||||
case escaping:
|
||||
unescapedTrailingSpaces = 0
|
||||
escaping = false
|
||||
switch char {
|
||||
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
|
||||
buffer.WriteByte(char)
|
||||
continue
|
||||
}
|
||||
// Not a special character, assume hex encoded octet
|
||||
if len(str) == i+1 {
|
||||
return nil, errors.New("got corrupted escaped character")
|
||||
}
|
||||
|
||||
dst := []byte{0}
|
||||
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode escaped character: %s", err)
|
||||
} else if n != 1 {
|
||||
return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n)
|
||||
}
|
||||
buffer.WriteByte(dst[0])
|
||||
i++
|
||||
case char == '\\':
|
||||
unescapedTrailingSpaces = 0
|
||||
escaping = true
|
||||
case char == '=' && attribute.Type == "":
|
||||
attribute.Type = stringFromBuffer()
|
||||
// Special case: If the first character in the value is # the
|
||||
// following data is BER encoded so we can just fast forward
|
||||
// and decode.
|
||||
if len(str) > i+1 && str[i+1] == '#' {
|
||||
i += 2
|
||||
index := strings.IndexAny(str[i:], ",+")
|
||||
var data string
|
||||
if index > 0 {
|
||||
data = str[i : i+index]
|
||||
} else {
|
||||
data = str[i:]
|
||||
}
|
||||
rawBER, err := enchex.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode BER encoding: %s", err)
|
||||
}
|
||||
packet, err := ber.DecodePacketErr(rawBER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode BER packet: %s", err)
|
||||
}
|
||||
buffer.WriteString(packet.Data.String())
|
||||
i += len(data) - 1
|
||||
}
|
||||
case char == ',' || char == '+' || char == ';':
|
||||
// We're done with this RDN or value, push it
|
||||
if len(attribute.Type) == 0 {
|
||||
return nil, errors.New("incomplete type, value pair")
|
||||
}
|
||||
attribute.Value = stringFromBuffer()
|
||||
rdn.Attributes = append(rdn.Attributes, attribute)
|
||||
attribute = new(AttributeTypeAndValue)
|
||||
if char == ',' || char == ';' {
|
||||
dn.RDNs = append(dn.RDNs, rdn)
|
||||
rdn = new(RelativeDN)
|
||||
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
|
||||
}
|
||||
case char == ' ' && buffer.Len() == 0:
|
||||
// ignore unescaped leading spaces
|
||||
continue
|
||||
default:
|
||||
if char == ' ' {
|
||||
// Track unescaped spaces in case they are trailing and we need to remove them
|
||||
unescapedTrailingSpaces++
|
||||
} else {
|
||||
// Reset if we see a non-space char
|
||||
unescapedTrailingSpaces = 0
|
||||
}
|
||||
buffer.WriteByte(char)
|
||||
}
|
||||
}
|
||||
if buffer.Len() > 0 {
|
||||
if len(attribute.Type) == 0 {
|
||||
return nil, errors.New("DN ended with incomplete type, value pair")
|
||||
}
|
||||
attribute.Value = stringFromBuffer()
|
||||
rdn.Attributes = append(rdn.Attributes, attribute)
|
||||
dn.RDNs = append(dn.RDNs, rdn)
|
||||
}
|
||||
return dn, nil
|
||||
}
|
||||
|
||||
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Returns true if they have the same number of relative distinguished names
|
||||
// and corresponding relative distinguished names (by position) are the same.
|
||||
func (d *DN) Equal(other *DN) bool {
|
||||
if len(d.RDNs) != len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].Equal(other.RDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
|
||||
// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com"
|
||||
// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com"
|
||||
// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com"
|
||||
func (d *DN) AncestorOf(other *DN) bool {
|
||||
if len(d.RDNs) >= len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
|
||||
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].Equal(otherRDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues
|
||||
// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type.
|
||||
// The order of attributes is not significant.
|
||||
// Case of attribute types is not significant.
|
||||
func (r *RelativeDN) Equal(other *RelativeDN) bool {
|
||||
if len(r.Attributes) != len(other.Attributes) {
|
||||
return false
|
||||
}
|
||||
return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
|
||||
}
|
||||
|
||||
func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
|
||||
for _, attr := range attrs {
|
||||
found := false
|
||||
for _, myattr := range r.Attributes {
|
||||
if myattr.Equal(attr) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
|
||||
// Case of the attribute type is not significant
|
||||
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
|
||||
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
|
||||
}
|
||||
|
||||
// EqualFold returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Returns true if they have the same number of relative distinguished names
|
||||
// and corresponding relative distinguished names (by position) are the same.
|
||||
// Case of the attribute type and value is not significant
|
||||
func (d *DN) EqualFold(other *DN) bool {
|
||||
if len(d.RDNs) != len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].EqualFold(other.RDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AncestorOfFold returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
|
||||
// Case of the attribute type and value is not significant
|
||||
func (d *DN) AncestorOfFold(other *DN) bool {
|
||||
if len(d.RDNs) >= len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
|
||||
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].EqualFold(otherRDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// EqualFold returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Case of the attribute type is not significant
|
||||
func (r *RelativeDN) EqualFold(other *RelativeDN) bool {
|
||||
if len(r.Attributes) != len(other.Attributes) {
|
||||
return false
|
||||
}
|
||||
return r.hasAllAttributesFold(other.Attributes) && other.hasAllAttributesFold(r.Attributes)
|
||||
}
|
||||
|
||||
func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
|
||||
for _, attr := range attrs {
|
||||
found := false
|
||||
for _, myattr := range r.Attributes {
|
||||
if myattr.EqualFold(attr) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// EqualFold returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
|
||||
// Case of the attribute type and value is not significant
|
||||
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
|
||||
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
/*
|
||||
Package ldap provides basic LDAP v3 functionality.
|
||||
*/
|
||||
package ldap
|
|
@ -0,0 +1,261 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// LDAP Result Codes
|
||||
const (
|
||||
LDAPResultSuccess = 0
|
||||
LDAPResultOperationsError = 1
|
||||
LDAPResultProtocolError = 2
|
||||
LDAPResultTimeLimitExceeded = 3
|
||||
LDAPResultSizeLimitExceeded = 4
|
||||
LDAPResultCompareFalse = 5
|
||||
LDAPResultCompareTrue = 6
|
||||
LDAPResultAuthMethodNotSupported = 7
|
||||
LDAPResultStrongAuthRequired = 8
|
||||
LDAPResultReferral = 10
|
||||
LDAPResultAdminLimitExceeded = 11
|
||||
LDAPResultUnavailableCriticalExtension = 12
|
||||
LDAPResultConfidentialityRequired = 13
|
||||
LDAPResultSaslBindInProgress = 14
|
||||
LDAPResultNoSuchAttribute = 16
|
||||
LDAPResultUndefinedAttributeType = 17
|
||||
LDAPResultInappropriateMatching = 18
|
||||
LDAPResultConstraintViolation = 19
|
||||
LDAPResultAttributeOrValueExists = 20
|
||||
LDAPResultInvalidAttributeSyntax = 21
|
||||
LDAPResultNoSuchObject = 32
|
||||
LDAPResultAliasProblem = 33
|
||||
LDAPResultInvalidDNSyntax = 34
|
||||
LDAPResultIsLeaf = 35
|
||||
LDAPResultAliasDereferencingProblem = 36
|
||||
LDAPResultInappropriateAuthentication = 48
|
||||
LDAPResultInvalidCredentials = 49
|
||||
LDAPResultInsufficientAccessRights = 50
|
||||
LDAPResultBusy = 51
|
||||
LDAPResultUnavailable = 52
|
||||
LDAPResultUnwillingToPerform = 53
|
||||
LDAPResultLoopDetect = 54
|
||||
LDAPResultSortControlMissing = 60
|
||||
LDAPResultOffsetRangeError = 61
|
||||
LDAPResultNamingViolation = 64
|
||||
LDAPResultObjectClassViolation = 65
|
||||
LDAPResultNotAllowedOnNonLeaf = 66
|
||||
LDAPResultNotAllowedOnRDN = 67
|
||||
LDAPResultEntryAlreadyExists = 68
|
||||
LDAPResultObjectClassModsProhibited = 69
|
||||
LDAPResultResultsTooLarge = 70
|
||||
LDAPResultAffectsMultipleDSAs = 71
|
||||
LDAPResultVirtualListViewErrorOrControlError = 76
|
||||
LDAPResultOther = 80
|
||||
LDAPResultServerDown = 81
|
||||
LDAPResultLocalError = 82
|
||||
LDAPResultEncodingError = 83
|
||||
LDAPResultDecodingError = 84
|
||||
LDAPResultTimeout = 85
|
||||
LDAPResultAuthUnknown = 86
|
||||
LDAPResultFilterError = 87
|
||||
LDAPResultUserCanceled = 88
|
||||
LDAPResultParamError = 89
|
||||
LDAPResultNoMemory = 90
|
||||
LDAPResultConnectError = 91
|
||||
LDAPResultNotSupported = 92
|
||||
LDAPResultControlNotFound = 93
|
||||
LDAPResultNoResultsReturned = 94
|
||||
LDAPResultMoreResultsToReturn = 95
|
||||
LDAPResultClientLoop = 96
|
||||
LDAPResultReferralLimitExceeded = 97
|
||||
LDAPResultInvalidResponse = 100
|
||||
LDAPResultAmbiguousResponse = 101
|
||||
LDAPResultTLSNotSupported = 112
|
||||
LDAPResultIntermediateResponse = 113
|
||||
LDAPResultUnknownType = 114
|
||||
LDAPResultCanceled = 118
|
||||
LDAPResultNoSuchOperation = 119
|
||||
LDAPResultTooLate = 120
|
||||
LDAPResultCannotCancel = 121
|
||||
LDAPResultAssertionFailed = 122
|
||||
LDAPResultAuthorizationDenied = 123
|
||||
LDAPResultSyncRefreshRequired = 4096
|
||||
|
||||
ErrorNetwork = 200
|
||||
ErrorFilterCompile = 201
|
||||
ErrorFilterDecompile = 202
|
||||
ErrorDebugging = 203
|
||||
ErrorUnexpectedMessage = 204
|
||||
ErrorUnexpectedResponse = 205
|
||||
ErrorEmptyPassword = 206
|
||||
)
|
||||
|
||||
// LDAPResultCodeMap contains string descriptions for LDAP error codes
|
||||
var LDAPResultCodeMap = map[uint16]string{
|
||||
LDAPResultSuccess: "Success",
|
||||
LDAPResultOperationsError: "Operations Error",
|
||||
LDAPResultProtocolError: "Protocol Error",
|
||||
LDAPResultTimeLimitExceeded: "Time Limit Exceeded",
|
||||
LDAPResultSizeLimitExceeded: "Size Limit Exceeded",
|
||||
LDAPResultCompareFalse: "Compare False",
|
||||
LDAPResultCompareTrue: "Compare True",
|
||||
LDAPResultAuthMethodNotSupported: "Auth Method Not Supported",
|
||||
LDAPResultStrongAuthRequired: "Strong Auth Required",
|
||||
LDAPResultReferral: "Referral",
|
||||
LDAPResultAdminLimitExceeded: "Admin Limit Exceeded",
|
||||
LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension",
|
||||
LDAPResultConfidentialityRequired: "Confidentiality Required",
|
||||
LDAPResultSaslBindInProgress: "Sasl Bind In Progress",
|
||||
LDAPResultNoSuchAttribute: "No Such Attribute",
|
||||
LDAPResultUndefinedAttributeType: "Undefined Attribute Type",
|
||||
LDAPResultInappropriateMatching: "Inappropriate Matching",
|
||||
LDAPResultConstraintViolation: "Constraint Violation",
|
||||
LDAPResultAttributeOrValueExists: "Attribute Or Value Exists",
|
||||
LDAPResultInvalidAttributeSyntax: "Invalid Attribute Syntax",
|
||||
LDAPResultNoSuchObject: "No Such Object",
|
||||
LDAPResultAliasProblem: "Alias Problem",
|
||||
LDAPResultInvalidDNSyntax: "Invalid DN Syntax",
|
||||
LDAPResultIsLeaf: "Is Leaf",
|
||||
LDAPResultAliasDereferencingProblem: "Alias Dereferencing Problem",
|
||||
LDAPResultInappropriateAuthentication: "Inappropriate Authentication",
|
||||
LDAPResultInvalidCredentials: "Invalid Credentials",
|
||||
LDAPResultInsufficientAccessRights: "Insufficient Access Rights",
|
||||
LDAPResultBusy: "Busy",
|
||||
LDAPResultUnavailable: "Unavailable",
|
||||
LDAPResultUnwillingToPerform: "Unwilling To Perform",
|
||||
LDAPResultLoopDetect: "Loop Detect",
|
||||
LDAPResultSortControlMissing: "Sort Control Missing",
|
||||
LDAPResultOffsetRangeError: "Result Offset Range Error",
|
||||
LDAPResultNamingViolation: "Naming Violation",
|
||||
LDAPResultObjectClassViolation: "Object Class Violation",
|
||||
LDAPResultResultsTooLarge: "Results Too Large",
|
||||
LDAPResultNotAllowedOnNonLeaf: "Not Allowed On Non Leaf",
|
||||
LDAPResultNotAllowedOnRDN: "Not Allowed On RDN",
|
||||
LDAPResultEntryAlreadyExists: "Entry Already Exists",
|
||||
LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited",
|
||||
LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs",
|
||||
LDAPResultVirtualListViewErrorOrControlError: "Failed because of a problem related to the virtual list view",
|
||||
LDAPResultOther: "Other",
|
||||
LDAPResultServerDown: "Cannot establish a connection",
|
||||
LDAPResultLocalError: "An error occurred",
|
||||
LDAPResultEncodingError: "LDAP encountered an error while encoding",
|
||||
LDAPResultDecodingError: "LDAP encountered an error while decoding",
|
||||
LDAPResultTimeout: "LDAP timeout while waiting for a response from the server",
|
||||
LDAPResultAuthUnknown: "The auth method requested in a bind request is unknown",
|
||||
LDAPResultFilterError: "An error occurred while encoding the given search filter",
|
||||
LDAPResultUserCanceled: "The user canceled the operation",
|
||||
LDAPResultParamError: "An invalid parameter was specified",
|
||||
LDAPResultNoMemory: "Out of memory error",
|
||||
LDAPResultConnectError: "A connection to the server could not be established",
|
||||
LDAPResultNotSupported: "An attempt has been made to use a feature not supported LDAP",
|
||||
LDAPResultControlNotFound: "The controls required to perform the requested operation were not found",
|
||||
LDAPResultNoResultsReturned: "No results were returned from the server",
|
||||
LDAPResultMoreResultsToReturn: "There are more results in the chain of results",
|
||||
LDAPResultClientLoop: "A loop has been detected. For example when following referrals",
|
||||
LDAPResultReferralLimitExceeded: "The referral hop limit has been exceeded",
|
||||
LDAPResultCanceled: "Operation was canceled",
|
||||
LDAPResultNoSuchOperation: "Server has no knowledge of the operation requested for cancellation",
|
||||
LDAPResultTooLate: "Too late to cancel the outstanding operation",
|
||||
LDAPResultCannotCancel: "The identified operation does not support cancellation or the cancel operation cannot be performed",
|
||||
LDAPResultAssertionFailed: "An assertion control given in the LDAP operation evaluated to false causing the operation to not be performed",
|
||||
LDAPResultSyncRefreshRequired: "Refresh Required",
|
||||
LDAPResultInvalidResponse: "Invalid Response",
|
||||
LDAPResultAmbiguousResponse: "Ambiguous Response",
|
||||
LDAPResultTLSNotSupported: "Tls Not Supported",
|
||||
LDAPResultIntermediateResponse: "Intermediate Response",
|
||||
LDAPResultUnknownType: "Unknown Type",
|
||||
LDAPResultAuthorizationDenied: "Authorization Denied",
|
||||
|
||||
ErrorNetwork: "Network Error",
|
||||
ErrorFilterCompile: "Filter Compile Error",
|
||||
ErrorFilterDecompile: "Filter Decompile Error",
|
||||
ErrorDebugging: "Debugging Error",
|
||||
ErrorUnexpectedMessage: "Unexpected Message",
|
||||
ErrorUnexpectedResponse: "Unexpected Response",
|
||||
ErrorEmptyPassword: "Empty password not allowed by the client",
|
||||
}
|
||||
|
||||
// Error holds LDAP error information
|
||||
type Error struct {
|
||||
// Err is the underlying error
|
||||
Err error
|
||||
// ResultCode is the LDAP error code
|
||||
ResultCode uint16
|
||||
// MatchedDN is the matchedDN returned if any
|
||||
MatchedDN string
|
||||
// Packet is the returned packet if any
|
||||
Packet *ber.Packet
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error { return e.Err }
|
||||
|
||||
// GetLDAPError creates an Error out of a BER packet representing a LDAPResult
|
||||
// The return is an error object. It can be casted to a Error structure.
|
||||
// This function returns nil if resultCode in the LDAPResult sequence is success(0).
|
||||
func GetLDAPError(packet *ber.Packet) error {
|
||||
if packet == nil {
|
||||
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")}
|
||||
}
|
||||
|
||||
if len(packet.Children) >= 2 {
|
||||
response := packet.Children[1]
|
||||
if response == nil {
|
||||
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet}
|
||||
}
|
||||
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
|
||||
if ber.Type(response.Children[0].Tag) == ber.Type(ber.TagInteger) || ber.Type(response.Children[0].Tag) == ber.Type(ber.TagEnumerated) {
|
||||
resultCode := uint16(response.Children[0].Value.(int64))
|
||||
if resultCode == 0 { // No error
|
||||
return nil
|
||||
}
|
||||
|
||||
if ber.Type(response.Children[1].Tag) == ber.Type(ber.TagOctetString) &&
|
||||
ber.Type(response.Children[2].Tag) == ber.Type(ber.TagOctetString) {
|
||||
return &Error{
|
||||
ResultCode: resultCode,
|
||||
MatchedDN: response.Children[1].Value.(string),
|
||||
Err: fmt.Errorf("%s", response.Children[2].Value.(string)),
|
||||
Packet: packet,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet}
|
||||
}
|
||||
|
||||
// NewError creates an LDAP error with the given code and underlying error
|
||||
func NewError(resultCode uint16, err error) error {
|
||||
return &Error{ResultCode: resultCode, Err: err}
|
||||
}
|
||||
|
||||
// IsErrorAnyOf returns true if the given error is an LDAP error with any one of the given result codes
|
||||
func IsErrorAnyOf(err error, codes ...uint16) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
serverError, ok := err.(*Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, code := range codes {
|
||||
if serverError.ResultCode == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
|
||||
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
|
||||
return IsErrorAnyOf(err, desiredResultCode)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue