pdns-auth-proxy/pdns_http.go

188 lines
4.7 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
type (
	// APIKeyTransport is an http.RoundTripper that adds the PowerDNS API key
	// (X-API-Key header) and JSON Content-Type/Accept headers to every request
	// before delegating it to Transport.
	APIKeyTransport struct {
		// Transport performs the actual round trip; http.DefaultTransport is
		// used when it is nil.
		Transport http.RoundTripper
		// APIKey is sent verbatim in the X-API-Key header.
		APIKey string
	}

	// httpStatusError is returned by ErrorTransport for a PowerDNS API
	// response whose status is 5xx or 401 Unauthorized.
	httpStatusError struct {
		// Response is the offending response. Its Body has already been read
		// and closed; use the Body field below instead.
		Response *http.Response
		Body     []byte // Copied from `Response.Body` to avoid problems with unclosed bodies later.
	}

	// ErrorTransport is an http.RoundTripper that turns 5xx and
	// 401 Unauthorized responses into *httpStatusError values. Because
	// http.Client.Do wraps transport errors, callers of the client see them
	// inside a *url.Error.
	ErrorTransport struct {
		// Transport performs the actual round trip; http.DefaultTransport is
		// used when it is nil.
		Transport http.RoundTripper
	}
)

// RoundTrip implements http.RoundTripper for APIKeyTransport. The headers are
// set on a clone of request, because a RoundTripper must not modify the
// request it is given.
func (t *APIKeyTransport) RoundTrip(request *http.Request) (*http.Response, error) {
	outgoing := request.Clone(request.Context())
	if outgoing.Header == nil {
		outgoing.Header = make(http.Header)
	}
	outgoing.Header.Set("X-API-Key", t.APIKey)
	outgoing.Header.Set("Content-Type", "application/json")
	outgoing.Header.Set("Accept", "application/json")

	next := t.Transport
	if next == nil {
		next = http.DefaultTransport
	}
	return next.RoundTrip(outgoing)
}

// Error implements error. A missing Response is reported as status 0 rather
// than panicking.
func (e *httpStatusError) Error() string {
	status := 0
	if e.Response != nil {
		status = e.Response.StatusCode
	}
	return fmt.Sprintf("http: non-successful response (status=%v body=%q)", status, e.Body)
}

// RoundTrip implements http.RoundTripper for ErrorTransport.
//
// Transport errors and responses with any other status are passed through
// unchanged. A 5xx or 401 Unauthorized response is handled differently: its
// body is read into an *httpStatusError, the body is closed, and a nil
// response is returned with that error.
func (t *ErrorTransport) RoundTrip(request *http.Request) (*http.Response, error) {
	next := t.Transport
	if next == nil {
		next = http.DefaultTransport
	}
	resp, err := next.RoundTrip(request)
	if err != nil {
		return resp, err
	}
	if resp.StatusCode < 500 && resp.StatusCode != http.StatusUnauthorized {
		return resp, nil
	}

	// The response is not handed back to the caller, so its body has to be
	// closed here.
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("http: failed to read response body (status=%v): %w", resp.StatusCode, err)
	}
	return nil, &httpStatusError{
		Response: resp,
		Body:     body,
	}
}
// Proxy is an HTTP handler that forwards the incoming request (method,
// request URI and body) to the PowerDNS server through sendQuery and relays
// the upstream answer to the client.
//
// Requests whose URI starts with "/stats/" are forwarded with that leading
// prefix replaced by "/". Every `href="/"` in the answer is then rewritten to
// `href="/stats/"` so that links keep pointing through the proxy.
//
// On success the upstream status code, headers and body are relayed. If the
// upstream call fails, the client receives a plain-text "Response is: ..."
// message. Its status is the upstream status when the failure is an
// *httpStatusError wrapped in a *url.Error, and 502 Bad Gateway otherwise.
func (p *PowerDNS) Proxy(w http.ResponseWriter, r *http.Request) {
	// The request context is canceled when the client goes away or the
	// handler returns, which aborts the upstream call as well.
	ctx := r.Context()

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Cannot read body", http.StatusBadRequest)
		return
	}
	if len(body) == 0 {
		// sendQuery attaches a request body only when it is non-nil.
		body = nil
	}

	uri := r.RequestURI
	stats := strings.HasPrefix(uri, "/stats/")
	if stats {
		// Strip only the leading prefix; a later "/stats/" in the path or
		// query string is forwarded untouched.
		uri = "/" + strings.TrimPrefix(uri, "/stats/")
	}

	var response []byte
	status, head, err := p.sendQuery(ctx, uri, r.Method, body, &response)
	if err != nil {
		code := http.StatusBadGateway
		reported := err
		if urlErr, ok := err.(*url.Error); ok {
			if statusErr, ok := urlErr.Err.(*httpStatusError); ok {
				reported = statusErr
				if statusErr.Response != nil {
					code = statusErr.Response.StatusCode
				}
			}
		}
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(code)
		fmt.Fprintf(w, "Response is: %v", reported)
		return
	}

	for name, values := range head {
		if stats && name == "Content-Length" {
			// The body is rewritten below, so the upstream length is stale.
			continue
		}
		w.Header()[name] = values
	}
	if stats {
		response = bytes.Replace(response, []byte(`href="/"`), []byte(`href="/stats/"`), -1)
	}
	w.WriteHeader(status)
	// A write error means the client is gone; there is nobody left to report to.
	_, _ = w.Write(response)
}
// sendQuery sends one HTTP request to the PowerDNS server at
// p.Scheme://p.Hostname:p.Port and returns the response status code and
// headers. The request is bound to ctx.
//
// endpoint is appended verbatim to the server address, so it may carry a
// query string; a leading "/" is added when missing. body, when non-nil, is
// sent as the request body.
//
// response selects what happens to the response body:
//   - nil: the body is read and discarded;
//   - *[]byte: the raw body is stored as-is;
//   - anything else: the body is JSON-decoded into it. A decode failure is
//     only logged at debug level, so callers must tolerate an untouched
//     response value.
//
// If the request cannot be built or sent, the status is 500 and the error is
// whatever http.NewRequest / p.Client.Do returned. If the caller asked for
// the body and reading it fails, that error is returned together with the
// real status and headers.
func (p *PowerDNS) sendQuery(ctx context.Context, endpoint, method string, body []byte, response interface{}) (int, http.Header, error) {
	host := p.Hostname
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		// Bare IPv6 literal: it must be bracketed before a port is appended.
		host = "[" + host + "]"
	}
	u := url.URL{Scheme: p.Scheme, Host: host + ":" + p.Port}
	if !strings.HasPrefix(endpoint, "/") {
		endpoint = "/" + endpoint
	}
	// Appended as text rather than assigned to u.Path, because URL.String
	// would percent-escape a '?' that starts a query string.
	uri := u.String() + endpoint
	p.LogDebug(fmt.Sprintf("pdns.client %s", uri))

	var req *http.Request
	var err error
	if body != nil {
		p.LogDebug(string(body))
		req, err = http.NewRequest(method, uri, bytes.NewReader(body))
	} else {
		req, err = http.NewRequest(method, uri, nil)
	}
	if err != nil {
		return 500, nil, err
	}

	resp, err := p.Client.Do(req.WithContext(ctx))
	if err != nil {
		return 500, nil, err
	}
	defer resp.Body.Close()
	p.LogDebug(fmt.Sprintf(" resp.Status=%s", resp.Status))

	raw, readErr := ioutil.ReadAll(resp.Body)
	if response == nil {
		// Nothing is handed back; the body is read to EOF only so that the
		// transport can reuse the connection, and a read error is irrelevant.
		return resp.StatusCode, resp.Header, nil
	}
	if readErr != nil {
		return resp.StatusCode, resp.Header, fmt.Errorf("pdns.client: reading response body: %w", readErr)
	}
	if buf, ok := response.(*[]byte); ok {
		// Raw bytes are wanted, so JSON decoding is skipped: it would
		// base64-decode a JSON string or turn a "null" body into nil.
		*buf = raw
		p.LogDebug(string(raw))
		return resp.StatusCode, resp.Header, nil
	}
	if err := json.Unmarshal(raw, response); err != nil {
		// Decode failures are deliberately not reported to the caller; they
		// are only logged.
		p.LogDebug(fmt.Sprintf("pdns.client: cannot decode response: %v", err))
		p.LogDebug(string(raw))
	}
	return resp.StatusCode, resp.Header, nil
}
// parseBaseURL splits baseURL into its scheme, host name, port and path, in
// that order.
//
// The host name is returned without IPv6 brackets (for example "::1" for
// "http://[::1]:8081"). When baseURL has no port, or an empty one, the port
// defaults to "443" for the https scheme and to "8081" for any other scheme.
// NOTE(review): 8081 presumably matches the PowerDNS webserver default; confirm.
// An error is returned only when url.Parse rejects baseURL.
func parseBaseURL(baseURL string) (string, string, string, string, error) {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "", "", "", "", err
	}

	port := u.Port()
	if port == "" {
		switch u.Scheme {
		case "https":
			port = "443"
		default:
			port = "8081"
		}
	}
	return u.Scheme, u.Hostname(), port, u.Path, nil
}