188 lines
4.7 KiB
Go
188 lines
4.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
type (
|
|
// APIKeyTransport definition
|
|
APIKeyTransport struct {
|
|
Transport http.RoundTripper
|
|
APIKey string
|
|
}
|
|
|
|
// Handle the answer from the PowerDNS API
|
|
httpStatusError struct {
|
|
Response *http.Response
|
|
Body []byte // Copied from `Response.Body` to avoid problems with unclosed bodies later.
|
|
}
|
|
|
|
// ErrorTransport defines the data structure for returning errors from the round tripper
|
|
ErrorTransport struct {
|
|
Transport http.RoundTripper
|
|
}
|
|
)
|
|
|
|
// RoundTrip implements http.RoundTripper for APIKeyTransport. It sends the
// request through t.Transport, falling back to http.DefaultTransport when
// Transport is nil, after setting the X-API-Key header to t.APIKey and
// declaring JSON for both the request body (Content-Type) and the expected
// answer (Accept).
//
// The headers are set on a clone of the request: the http.RoundTripper
// contract forbids modifying the caller's request.
func (t *APIKeyTransport) RoundTrip(request *http.Request) (*http.Response, error) {
	transport := t.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	req := request.Clone(request.Context())
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	req.Header.Set("X-API-Key", t.APIKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	return transport.RoundTrip(req)
}
|
|
|
|
// Error implements the error interface. The message carries the HTTP status
// code and the captured response body, quoted with %q so binary or
// multi-line content stays on one readable line.
func (e *httpStatusError) Error() string {
	status := e.Response.StatusCode
	return fmt.Sprintf("http: non-successful response (status=%v body=%q)", status, e.Body)
}
|
|
|
|
// RoundTrip implements http.RoundTripper for ErrorTransport. It sends the
// request through t.Transport (http.DefaultTransport when nil).
//
// A 5xx or 401 Unauthorized answer is turned into an *httpStatusError that
// holds the response and a copy of its body; the response body itself is
// closed. Transport errors and all other responses are passed through
// unchanged.
func (t *ErrorTransport) RoundTrip(request *http.Request) (*http.Response, error) {
	transport := t.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	resp, err := transport.RoundTrip(request)
	if err != nil {
		return resp, err
	}
	if resp.StatusCode < http.StatusInternalServerError && resp.StatusCode != http.StatusUnauthorized {
		return resp, nil
	}

	// The response is consumed here: copy the body so the error stays
	// readable after resp.Body has been closed.
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("http: failed to read response body (status=%v): %w", resp.StatusCode, err)
	}
	return nil, &httpStatusError{Response: resp, Body: body}
}
|
|
|
|
// Proxy is an HTTP handler that forwards the incoming request (method, URI
// and body) to the PowerDNS server through sendQuery and writes the answer
// back to the client.
//
// A request URI starting with "/stats/" is forwarded with that leading
// prefix replaced by "/", and every href="/" in the answer is rewritten to
// href="/stats/" so links keep pointing under the proxy prefix. For such
// requests the upstream Content-Length header is not copied, since the
// rewrite can change the body length.
//
// On success the upstream headers (all values) and status code are
// forwarded. On failure the client receives "Response is: <error>" with the
// status carried by an *httpStatusError, or 502 Bad Gateway for any other
// error.
func (p *PowerDNS) Proxy(w http.ResponseWriter, r *http.Request) {
	// Bind the upstream query to the client's request so that a client
	// disconnect (or server shutdown) cancels the call to PowerDNS.
	ctx := r.Context()

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Cannot read body", http.StatusBadRequest)
		return
	}
	if len(body) == 0 {
		// sendQuery only attaches a request body when it is non-nil; this
		// also covers empty chunked bodies where ContentLength is -1.
		body = nil
	}

	uri := r.RequestURI
	stats := strings.HasPrefix(uri, "/stats/")
	if stats {
		// Strip only the leading prefix; a later "/stats/" belongs to the upstream URI.
		uri = "/" + strings.TrimPrefix(uri, "/stats/")
	}

	var response []byte
	status, head, err := p.sendQuery(ctx, uri, r.Method, body, &response)
	if err != nil {
		code := http.StatusBadGateway
		if urlErr, ok := err.(*url.Error); ok {
			if httpErr, ok := urlErr.Err.(*httpStatusError); ok {
				// Forward the upstream failure status and report the
				// failure itself, without the url.Error wrapper.
				code = httpErr.Response.StatusCode
				err = httpErr
			}
		}
		w.WriteHeader(code)
		fmt.Fprintf(w, "Response is: %v", err)
		return
	}

	for name, values := range head {
		if stats && name == "Content-Length" {
			continue
		}
		// Copy every value, not just the first, so multi-valued headers
		// such as Set-Cookie survive the proxy.
		w.Header()[name] = append([]string(nil), values...)
	}
	if stats {
		response = bytes.Replace(response, []byte(`href="/"`), []byte(`href="/stats/"`), -1)
	}
	w.WriteHeader(status)
	// A write error means the client went away; there is nobody left to tell.
	_, _ = w.Write(response)
}
|
|
|
|
// sendQuery sends an HTTP request to the PowerDNS API and returns the
// response status code and headers.
//
// The request URL is p.Scheme://p.Hostname:p.Port followed by endpoint, with
// a "/" inserted when endpoint does not start with one. body, when non-nil,
// is sent as the request body. The request is bound to ctx.
//
// response selects what happens with the answer body:
//   - nil: the body is read and discarded;
//   - *[]byte: it receives the raw body bytes;
//   - anything else: the body is decoded into it with json.Unmarshal. A
//     decoding failure is deliberately not reported as an error (the raw
//     body is only logged), so answers without a JSON body do not fail the
//     call.
//
// When the request cannot be built or sent, the status is 500 and the
// headers are nil. A failure reading a wanted answer body is returned as an
// error together with the real status and headers.
func (p *PowerDNS) sendQuery(ctx context.Context, endpoint, method string, body []byte, response interface{}) (int, http.Header, error) {
	host := p.Hostname
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		// Bracket a bare IPv6 literal (as net.JoinHostPort does) so the
		// port separator stays unambiguous.
		host = "[" + host + "]"
	}
	u := url.URL{Scheme: p.Scheme, Host: host + ":" + p.Port}
	if !strings.HasPrefix(endpoint, "/") {
		u.Path = "/"
	}
	uri := u.String() + endpoint
	p.LogDebug(fmt.Sprintf("pdns.client %s", uri))

	var (
		req *http.Request
		err error
	)
	if body != nil {
		p.LogDebug(string(body))
		req, err = http.NewRequest(method, uri, bytes.NewReader(body))
	} else {
		// Pass an untyped nil reader so the request has no body at all.
		req, err = http.NewRequest(method, uri, nil)
	}
	if err != nil {
		return http.StatusInternalServerError, nil, err
	}

	resp, err := p.Client.Do(req.WithContext(ctx))
	if err != nil {
		return http.StatusInternalServerError, nil, err
	}
	defer resp.Body.Close()
	p.LogDebug(fmt.Sprintf(" resp.Status=%s", resp.Status))

	raw, readErr := ioutil.ReadAll(resp.Body)
	if response == nil {
		// The caller does not want the body; it is read only to drain the
		// connection, so a read failure does not matter here.
		return resp.StatusCode, resp.Header, nil
	}
	if readErr != nil {
		return resp.StatusCode, resp.Header, fmt.Errorf("pdns.client: reading response body: %w", readErr)
	}

	if ar, ok := response.(*[]byte); ok {
		// Must be checked before json.Unmarshal: Unmarshal would try to
		// base64-decode a JSON string into a []byte, and would leave it
		// empty for a "null" body.
		*ar = raw
		p.LogDebug(string(raw))
		return resp.StatusCode, resp.Header, nil
	}
	if err := json.Unmarshal(raw, response); err != nil {
		p.LogDebug(string(raw))
	}
	return resp.StatusCode, resp.Header, nil
}
|
|
|
|
// parseBaseURL decompose an URL into parts
|
|
func parseBaseURL(baseURL string) (string, string, string, string, error) {
|
|
u, err := url.Parse(baseURL)
|
|
if err != nil {
|
|
return "", "", "", "", err
|
|
}
|
|
hp := strings.Split(u.Host, ":")
|
|
hostname := hp[0]
|
|
var port string
|
|
if len(hp) > 1 {
|
|
port = hp[1]
|
|
} else {
|
|
if u.Scheme == "https" {
|
|
port = "443"
|
|
} else {
|
|
port = "8081"
|
|
}
|
|
}
|
|
return u.Scheme, hostname, port, u.Path, nil
|
|
}
|