From 546b08039f125193856b2508d5b50a29767f057b Mon Sep 17 00:00:00 2001 From: Xavier Henner Date: Wed, 26 Mar 2025 13:47:05 +0100 Subject: [PATCH] add multi instance support --- Makefile | 15 ++--- config.go | 51 +++++++++++++++-- http.go | 57 ++++++++++++++---- jsonrpc.go | 89 +++++++++++++++++------------ pdns-proxy.conf.example | 7 ++- test_todo/pdns-proxy-test.conf | 15 ++--- test_todo/pdns-proxy-unit-test.conf | 9 ++- 7 files changed, 170 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index cd30bcc..b6e6b06 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ BUILDDIR := ${PREFIX}/build DOCKER_IMAGE := $(docker images --format "{{.Repository}}" --filter=reference='$(PROJECTNAME):$(GITCOMMIT)') GOOSARCHES := linux/amd64 -all: build fmt lint test +all: build fmt test .PHONY: fmt fmt: ## Verifies all files have been `gofmt`ed. @@ -30,13 +30,6 @@ fmt: ## Verifies all files have been `gofmt`ed. @if [ ! -z "${var}" ]; then exit 1; fi $(MAKE) -C client fmt -.PHONY: lint -lint: ## Verifies `golint` passes. - @echo "+ $@" - $(eval var = $(shell golint ./... | grep -v vendor | tee /dev/stderr)) - @if [ ! -z "${var}" ]; then exit 1; fi - $(MAKE) -C client lint - .PHONY: tag tag: ## Create a new git tag to prepare to build a release git tag -sa $(VERSION) -m "$(VERSION)" @@ -45,8 +38,8 @@ tag: ## Create a new git tag to prepare to build a release .PHONY: test test: ## Generates test certificates & run unit tests. $(MAKE) -C fixtures/test - go vet -mod=vendor - go test -v -mod=vendor -coverprofile=coverage.out + go vet + go test -v -coverprofile=coverage.out $(MAKE) -C client test .PHONY: resetreplay @@ -62,7 +55,7 @@ deb: .PHONY: build build: ## Builds a static executable. @echo "+ $@" - CGO_ENABLED=0 $(GO) build -mod=vendor \ + CGO_ENABLED=0 $(GO) build \ -o $(PROJECTNAME) \ -tags "static_build netgo" \ -trimpath \ diff --git a/config.go b/config.go index e2c15fd..f83bfd4 100644 --- a/config.go +++ b/config.go @@ -1,29 +1,72 @@ package main import ( + "errors" "fmt" "log" + "regexp" "strings" "github.com/pyke369/golang-support/uconfig" ) +func getPDNSInstances(config *uconfig.UConfig) ([]pdnsInstance, error) { + var err error + instances := []pdnsInstance{} + hasDefaultInstance := false + i := pdnsInstance{} + for _, profile := range config.GetPaths("config.pdns") { + if i.dns, err = NewClient( + config.GetString(profile+".api-url", "http://127.0.0.1:8081/api/v1/servers/localhost"), + config.GetString(profile+".api-key", ""), + int(config.GetInteger(profile+".timeout", 3)), + int(config.GetInteger(profile+".defaultTTL", 3600))); err != nil { + return nil, err + } + i.isDefault = config.GetBoolean(profile+".default", false) + if i.isDefault && hasDefaultInstance { + return nil, errors.New("There can be only one default instance") + } else { + hasDefaultInstance = true + } + + for _, r := range parseConfigArray(config, profile+".whenRegexp") { + re, err := regexp.Compile(r) + if err != nil { + return nil, err + } + i.regexp = append(i.regexp, re) + } + instances = append(instances, i) + } + if len(instances) == 0 { + return nil, errors.New("There must be at least one instance") + } + if len(instances) > 1 && !hasDefaultInstance { + return nil, errors.New("There must be at least one default instance") + } + return instances, nil +} + func loadConfig(configFile string) (*HTTPServer, error) { // Parse the configuration file config, err := uconfig.New(configFile) if err != nil { return nil, err } + + instances, err := getPDNSInstances(config) + if err != nil { + return nil, err + } + h := NewHTTPServer( config.GetString("config.http.port", "127.0.0.01:8080"), config.GetString("config.http.key", ""), config.GetString("config.http.cert", ""), config.GetString("config.http.crl", ""), config.GetString("config.http.ca", ""), - config.GetString("config.pdns.api-url", "http://127.0.0.1:8081/api/v1/servers/localhost"), - config.GetString("config.pdns.api-key", ""), - int(config.GetInteger("config.pdns.timeout", 3)), - int(config.GetInteger("config.pdns.defaultTTL", 3600)), + instances, ) populateHTTPServerZoneProfiles(config, h) populateHTTPServerAcls(config, h) diff --git a/http.go b/http.go index 35f6dc6..4cd96b2 100644 --- a/http.go +++ b/http.go @@ -40,11 +40,16 @@ type ( certPool *x509.CertPool debug bool m sync.RWMutex - dns *PowerDNS + dnsInstances []pdnsInstance nonceGen string certCache map[string]time.Time zoneProfiles map[string]*zoneProfile } + pdnsInstance struct { + dns *PowerDNS + isDefault bool + regexp []*regexp.Regexp + } zoneProfile struct { Default bool NameServers []string @@ -73,7 +78,7 @@ func (e JSONRPCError) String() string { } // NewHTTPServer initializes HTTPServer -func NewHTTPServer(port, key, cert, crl, ca, pdnsServer, pdnsKey string, timeout, ttl int) *HTTPServer { +func NewHTTPServer(port string, key string, cert string, crl string, ca string, instances []pdnsInstance) *HTTPServer { rand.Seed(time.Now().UnixNano()) h := HTTPServer{ Port: port, @@ -87,6 +92,7 @@ func NewHTTPServer(port, key, cert, crl, ca, pdnsServer, pdnsKey string, timeout pdnsAcls: []*PdnsACL{}, certPool: x509.NewCertPool(), decodedCA: []*x509.Certificate{}, + dnsInstances: instances, } rawCA, err := ioutil.ReadFile(ca) @@ -109,9 +115,6 @@ func NewHTTPServer(port, key, cert, crl, ca, pdnsServer, pdnsKey string, timeout h.certPool.AddCert(cert) h.decodedCA = append(h.decodedCA, cert) } - if h.dns, err = NewClient(pdnsServer, pdnsKey, timeout, ttl); err != nil { - log.Fatal(err) - } if _, err := h.RefreshCRL(); err != nil { log.Fatal(err) } @@ -177,7 +180,9 @@ func (h *HTTPServer) unlock() { // Debug facilities func (h *HTTPServer) Debug() { h.debug = true - h.dns.Debug() + for _, d := range h.dnsInstances { + d.dns.Debug() + } } // verifyNonce check that the once is valid and less than 10s old @@ -359,7 +364,7 @@ func (h *HTTPServer) jrpcDecodeQuery(body []byte) (JSONArray, []byte, []byte, bo b, message := clearsign.Decode(body) // this is not a valid PGP signed payload, meaning message is all we got if b == nil { - jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h.dns) + jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h) return jsonRPC, message, nil, wasArray, err } // this is a valid PGP signed payload, we can extract the real payload and @@ -375,7 +380,7 @@ func (h *HTTPServer) jrpcDecodeQuery(body []byte) (JSONArray, []byte, []byte, bo if err != nil { return JSONArray{}, message, signature, false, err } - jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h.dns) + jsonRPC, wasArray, err := ParsejsonRPCRequest(message, h) return jsonRPC, message, signature, wasArray, err } @@ -416,6 +421,27 @@ func (h *HTTPServer) jsonRPCServe(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, jsonRPC.Run(h, username, wasArray, r.Header.Get("PDNS-Output") == "plaintext")) } +// Find the right PDNS instance to use for a given query +func (h *HTTPServer) findDNSInstance(query string) *PowerDNS { + var d *PowerDNS + + if len(h.dnsInstances) == 1 { + return h.dnsInstances[0].dns + } + for _, i := range h.dnsInstances { + if i.isDefault { + d = i.dns + } + for _, re := range i.regexp { + if !re.MatchString(query) { + continue + } + return i.dns + } + } + return d +} + // PowerDNS native API support. Add certificate support for auth func (h *HTTPServer) nativeAPIServe(w http.ResponseWriter, r *http.Request) { // Check if the user/server is allowed to perform the required action @@ -424,20 +450,29 @@ func (h *HTTPServer) nativeAPIServe(w http.ResponseWriter, r *http.Request) { http.Error(w, certError.Error(), 403) return } + + // find the target PDNS instance based on the URL + dns := h.findDNSInstance(r.RequestURI) + log.Println("[Native API] User", commonName, "requested", r.Method, "on", r.RequestURI) - if !h.nativeValidAuth(strings.TrimPrefix(r.RequestURI, h.dns.apiURL+"/"), commonName, r.Method) { + if !h.nativeValidAuth(strings.TrimPrefix(r.RequestURI, dns.apiURL+"/"), commonName, r.Method) { http.Error(w, "The user "+commonName+" is not authorized to perform this action", 403) log.Println("[Native API] User ", commonName, " was not authorized to perform the action") return } - h.dns.Proxy(w, r) + dns.Proxy(w, r) } // GetZoneConfig returns a configuration for a new zone func (h *HTTPServer) GetZoneConfig(s string) (string, string, []string, JSONArray, bool, error) { valid := "" def := JSONArray{} + for zoneType, profile := range h.zoneProfiles { + // if there is only one profile, it's the good one + if len(h.zoneProfiles) == 1 { + return zoneType, profile.SOA, profile.NameServers, def, profile.AutoInc, nil + } for _, re := range profile.Regexp { if !re.MatchString(s) { continue @@ -557,7 +592,7 @@ func (h *HTTPServer) Run() { mux.HandleFunc("/stats/", h.nativeAPIServe) mux.HandleFunc("/style.css", h.nativeAPIServe) if _, err := os.Stat("./web"); err == nil && h.debug { - h.dns.LogDebug("serving local files") + h.dnsInstances[0].dns.LogDebug("serving local files") mux.Handle("/", http.FileServer(http.Dir("web"))) } else { mux.Handle("/", http.FileServer(http.FS(serverRoot))) diff --git a/jsonrpc.go b/jsonrpc.go index b93fb4a..45f913f 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -106,10 +106,10 @@ func (j JSONInput) String() string { } // SetDefaults modify the input by setting the right default parameters -func (ja JSONArray) SetDefaults(p *PowerDNS) JSONArray { +func (ja JSONArray) SetDefaults(h *HTTPServer) JSONArray { for i, j := range ja { if j.Params.TTL == 0 { - ja[i].Params.TTL = p.DefaultTTL + ja[i].Params.TTL = h.findDNSInstance(j.Params.Name).DefaultTTL } } return ja @@ -136,7 +136,8 @@ func (ja JSONArray) Run(h *HTTPServer, user string, wasArray, textOnly bool) str continue } listResult = append(listResult, JSONRPCNewError(-32000, j.ID, last.Error)) - h.dns.LogDebug(last.Error) + + h.findDNSInstance(j.Params.Name).LogDebug(last.Error) if !j.Params.IgnoreError { break } @@ -162,23 +163,31 @@ func (j *JSONInput) Run(h *HTTPServer, user string) []*JSONRPCResult { if err := j.Normalize(); err != nil { return append(ret, j.JSONRPCResult("", "", err)) } + + d := h.findDNSInstance(j.Params.Name) + switch j.Method { // list is a spacial case, it doesn't imply a DNSQuery() object case "list": - result, err := h.dns.ListZones(j.Params.Name) - // we apply the acl after the fact for the list method - result = j.FilterList(result) + var err error + total := DNSZones{} + for _, dns := range h.dnsInstances { + result, err := dns.dns.ListZones(j.Params.Name) + // we apply the acl after the fact for the list method + result = j.FilterList(result) - if err == nil && len(result) == 0 { - err = errors.New("Unknown domain") + if err == nil && len(result) == 0 { + err = errors.New("Unknown domain") + } + total = append(total, result...) } - res := j.JSONRPCResult(result.List("\n"), "", err) - for i := range result { - res.Raw = append(res.Raw, result[i].Name) + res := j.JSONRPCResult(total.List("\n"), "", err) + for i := range total { + res.Raw = append(res.Raw, total[i].Name) } return append(ret, res) case "domain": - parentName, err := h.dns.GetDomain(j.Params.Name) + parentName, err := d.GetDomain(j.Params.Name) res := j.JSONRPCResult(parentName, "", err) res.Raw = append(res.Raw, parentName) return append(ret, res) @@ -212,7 +221,7 @@ func (j *JSONInput) Run(h *HTTPServer, user string) []*JSONRPCResult { ret = append(ret, result) continue } - code, _, err := h.dns.Execute(act) + code, _, err := d.Execute(act) switch { case err == nil && code == 204: result.Result = "Command Successfull" @@ -277,15 +286,17 @@ func (j *JSONInput) Normalize() error { // pdns. It can change the content of j to force dry run mode func (j *JSONInput) DNSQueries(h *HTTPServer) ([]*DNSQuery, error) { var err error + + d := h.findDNSInstance(j.Params.Name) switch j.Method { case "search": - result, err := h.dns.Search(j.Params.Name) + result, err := d.Search(j.Params.Name) // we apply the acl after the fact for the search method result = j.FilterSearch(result) j.Params.DryRun = true return []*DNSQuery{result.DNSQuery()}, err case "dump": - result, err := h.dns.Zone(j.Params.Name) + result, err := d.Zone(j.Params.Name) j.Params.DryRun = true return []*DNSQuery{result}, err case "newzone": @@ -297,13 +308,13 @@ func (j *JSONInput) DNSQueries(h *HTTPServer) ([]*DNSQuery, error) { return append(otherActions, newZone.TransformIntoDNSQuery()), nil } result := &DNSQuery{} - code, _, err := h.dns.ExecuteZone(newZone, result) + code, _, err := d.ExecuteZone(newZone, result) if err == nil && code != 201 { err = fmt.Errorf("The return code was %d for the creation of zone %s", code, j.Params.Name) } return append(otherActions, result), err } - current, err := h.dns.GetRecord(j.Params.Name, j.Method, j.ignoreBadDomain) + current, err := d.GetRecord(j.Params.Name, j.Method, j.ignoreBadDomain) if err != nil { return nil, err } @@ -371,7 +382,8 @@ func (j *JSONInput) DNSQueriesNS(h *HTTPServer, current *DNSQuery) ([]*DNSQuery, if trimPoint(j.Params.Name) == trimPoint(current.Domain) { return nil, fmt.Errorf("You cannot change the NS of a local zone") } - subZone, err := h.dns.Zone(j.Params.Name) + d := h.findDNSInstance(j.Params.Name) + subZone, err := d.Zone(j.Params.Name) if err != nil { return nil, err } @@ -443,7 +455,8 @@ func (j *JSONInput) DNSQueriesTTL(current *DNSQuery) ([]*DNSQuery, error) { // DNSQueriesA is the DNSQueries method for the A and AAAA commands // it checks the command is legal, and can make reverse DNS changes if needed func (j *JSONInput) DNSQueriesA(h *HTTPServer, forward *DNSQuery) ([]*DNSQuery, error) { - reverse, err := h.dns.GetReverse(j.Params.Value, j.Method == "a") + d := h.findDNSInstance(j.Params.Name) + reverse, err := d.GetReverse(j.Params.Value, j.Method == "a") if err != nil { return nil, err } @@ -458,13 +471,14 @@ func (j *JSONInput) DNSQueriesA(h *HTTPServer, forward *DNSQuery) ([]*DNSQuery, forward.ChangeValue(j.Params.Name, j.Params.Value, j.Method, false, askForReverse) return []*DNSQuery{forward}, nil } - actions, err := h.dns.ReverseChanges(forward, j.Params.Value) + actions, err := d.ReverseChanges(forward, j.Params.Value) forward.ChangeValue(j.Params.Name, j.Params.Value, j.Method, true, askForReverse) return append(actions, forward), err } // DNSQueriesDelete is the DNSQueries method for the Delete command func (j *JSONInput) DNSQueriesDelete(h *HTTPServer, current *DNSQuery) ([]*DNSQuery, error) { + d := h.findDNSInstance(j.Params.Name) reverses, useful, err := current.SplitDeletionQuery(j.Params.Name, j.Params.Value) if err != nil { return nil, err @@ -477,7 +491,7 @@ func (j *JSONInput) DNSQueriesDelete(h *HTTPServer, current *DNSQuery) ([]*DNSQu ret := []*DNSQuery{current} // add the reverse changes if needed for _, r := range reverses { - parentName, err := h.dns.GetDomain(r.Name) + parentName, err := d.GetDomain(r.Name) if err != nil { return nil, err } @@ -486,7 +500,7 @@ func (j *JSONInput) DNSQueriesDelete(h *HTTPServer, current *DNSQuery) ([]*DNSQu continue } ip := ptrToIP(r.Name) - if h.dns.IsUsed(ip, j.Params.Name, []string{"A", "AAAA"}) { + if d.IsUsed(ip, j.Params.Name, []string{"A", "AAAA"}) { message := "Reverse issue : %s is the reverse for %s and will be removed\n" message += "But other records are pointing to %s as well. Please cleanup first\n" return ret, fmt.Errorf(message, j.Params.Name, ip, ip) @@ -497,12 +511,13 @@ func (j *JSONInput) DNSQueriesDelete(h *HTTPServer, current *DNSQuery) ([]*DNSQu // CheckPTR validates that the query is a valid PTR func (j *JSONInput) CheckPTR(h *HTTPServer) error { + d := h.findDNSInstance(j.Params.Name) ip := ptrToIP(j.Params.Name) if ip == "" { return fmt.Errorf("%s is not a valid PTR", j.Params.Name) } target := &DNSQuery{} - if err := h.dns.CanCreate(j.Params.Value, true, target); err != nil || target.Len() == 0 { + if err := d.CanCreate(j.Params.Value, true, target); err != nil || target.Len() == 0 { return err } for _, rec := range target.RRSets[0].Records { @@ -515,11 +530,12 @@ func (j *JSONInput) CheckPTR(h *HTTPServer) error { // CheckSRV validates that the query is a valid SRV func (j *JSONInput) CheckSRV(h *HTTPServer) error { + d := h.findDNSInstance(j.Params.Name) name := validSRV(j.Params.Value) if name == "" { return fmt.Errorf("%s is not a valid SRV", j.Params.Value) } - return h.dns.CanCreate(name, false, nil) + return d.CanCreate(name, false, nil) } // CheckCAA validates that the query is a valid CAA @@ -535,10 +551,11 @@ func (j *JSONInput) CheckCAA() error { // CheckMX validates that the query is a valid MX func (j *JSONInput) CheckMX(h *HTTPServer) error { name := validMX(j.Params.Value) + d := h.findDNSInstance(j.Params.Name) if name == "" { return fmt.Errorf("%s is not a valid MX", j.Params.Value) } - return h.dns.CanCreate(name, true, nil) + return d.CanCreate(name, true, nil) } // CheckCNAME validates that the query is a valid CNAME @@ -550,7 +567,8 @@ func (j *JSONInput) CheckCNAME(h *HTTPServer) error { if !validName(j.Params.Value, false) { return fmt.Errorf("%s is not a valid DNS Name", j.Params.Value) } - return h.dns.CanCreate(j.Params.Value, false, nil) + d := h.findDNSInstance(j.Params.Name) + return d.CanCreate(j.Params.Value, false, nil) } // NewZone create a new zone in the DNS. @@ -565,10 +583,11 @@ func (j *JSONInput) NewZone(h *HTTPServer) (z *DNSZone, otherActions []*DNSQuery if soa == "" { soa = fmt.Sprintf("%s hostmaster.%s 0 28800 7200 604800 86400", nameServers[0], j.Params.Name) } - parentName, zoneErr := h.dns.GetDomain(j.Params.Name) + dns := h.findDNSInstance(j.Params.Name) + parentName, zoneErr := dns.GetDomain(j.Params.Name) parentName = trimPoint(parentName) - z, err = h.dns.NewZone(j.Params.Name, zoneType, soa, j.user, j.Params.Comment, + z, err = dns.NewZone(j.Params.Name, zoneType, soa, j.user, j.Params.Comment, j.Params.TTL, nameServers, autoInc) if err != nil { return nil, nil, err @@ -589,7 +608,7 @@ func (j *JSONInput) NewZone(h *HTTPServer) (z *DNSZone, otherActions []*DNSQuery entry.Params.Name = fmt.Sprintf("%s.%s", d.Params.Name, j.Params.Name) } if err := entry.Normalize(); err != nil { - h.dns.LogDebug(err) + dns.LogDebug(err) continue } actions, err := entry.DNSQueries(h) @@ -605,14 +624,14 @@ func (j *JSONInput) NewZone(h *HTTPServer) (z *DNSZone, otherActions []*DNSQuery return } // we must create the NS records in the parent zone too - glue, err := h.dns.SetNameServers(j.Params.Name, parentName, nameServers) + glue, err := dns.SetNameServers(j.Params.Name, parentName, nameServers) if err != nil { return nil, nil, err } otherActions = append(otherActions, glue) // check if there are records in the parent zone - records, err := h.dns.Zone(j.Params.Name) + records, err := dns.Zone(j.Params.Name) if err != nil { err = nil return @@ -622,7 +641,7 @@ func (j *JSONInput) NewZone(h *HTTPServer) (z *DNSZone, otherActions []*DNSQuery z.AddEntries(records) // and delete them in the parent - delete, err := h.dns.Zone(j.Params.Name) + delete, err := dns.Zone(j.Params.Name) if err != nil { return nil, nil, err } @@ -677,16 +696,16 @@ func (j *JSONInput) FilterList(z []*DNSZone) []*DNSZone { } // ParsejsonRPCRequest read the payload of the query and put it in the structure -func ParsejsonRPCRequest(s []byte, d *PowerDNS) (JSONArray, bool, error) { +func ParsejsonRPCRequest(s []byte, h *HTTPServer) (JSONArray, bool, error) { var inSimple *JSONInput var inArray JSONArray if json.Unmarshal(s, &inArray); len(inArray) > 0 { - return inArray.SetDefaults(d), true, nil + return inArray.SetDefaults(h), true, nil } if err := json.Unmarshal(s, &inSimple); err != nil { return nil, false, err } - return JSONArray{inSimple}.SetDefaults(d), false, nil + return JSONArray{inSimple}.SetDefaults(h), false, nil } func isDryRun(j JSONArray) bool { diff --git a/pdns-proxy.conf.example b/pdns-proxy.conf.example index 116194c..a010ee0 100644 --- a/pdns-proxy.conf.example +++ b/pdns-proxy.conf.example @@ -70,8 +70,11 @@ config } pdns: { - api-key: "" - api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" + myServer: + { + api-key: "" + api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" + } } zoneProfile: { diff --git a/test_todo/pdns-proxy-test.conf b/test_todo/pdns-proxy-test.conf index e1bd0b0..c720b91 100644 --- a/test_todo/pdns-proxy-test.conf +++ b/test_todo/pdns-proxy-test.conf @@ -76,14 +76,15 @@ config key: "fixtures/test/server-key.pem" cert: "fixtures/test/server-cert.pem" } - pdns + pdns: { - api-key: "123password" - api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" - timeout: 300 - defaultTTL: 172800 - - + instance: + { + api-key: "123password" + api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" + timeout: 300 + defaultTTL: 172800 + } } zoneProfile { diff --git a/test_todo/pdns-proxy-unit-test.conf b/test_todo/pdns-proxy-unit-test.conf index e46ab75..144083d 100644 --- a/test_todo/pdns-proxy-unit-test.conf +++ b/test_todo/pdns-proxy-unit-test.conf @@ -65,9 +65,12 @@ config } pdns: { - api-key: "123password" - api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" - defaultTTL: 172800 + instance: + { + api-key: "123password" + api-url: "http://127.0.0.1:8081/api/v1/servers/localhost" + defaultTTL: 172800 + } } zoneProfile: {