package main import ( "bytes" "context" "encoding/json" "fmt" "io/ioutil" "net/http" "net/url" "strings" ) type ( // APIKeyTransport definition APIKeyTransport struct { Transport http.RoundTripper APIKey string } // Handle the answer from the PowerDNS API httpStatusError struct { Response *http.Response Body []byte // Copied from `Response.Body` to avoid problems with unclosed bodies later. } // ErrorTransport defines the data structure for returning errors from the round tripper ErrorTransport struct { Transport http.RoundTripper } ) // RoundTrip implements http.RoundTripper for ApiKeyTransport (RoundTrip method) func (t *APIKeyTransport) RoundTrip(request *http.Request) (*http.Response, error) { request.Header.Set("X-API-Key", fmt.Sprintf("%s", t.APIKey)) request.Header.Set("Content-Type", "application/json") request.Header.Set("Accept", "application/json") return t.Transport.RoundTrip(request) } func (err *httpStatusError) Error() string { return fmt.Sprintf("http: non-successful response (status=%v body=%q)", err.Response.StatusCode, err.Body) } // RoundTrip is the round tripper for the error transport func (t *ErrorTransport) RoundTrip(request *http.Request) (*http.Response, error) { resp, err := t.Transport.RoundTrip(request) if err != nil { return resp, err } if resp.StatusCode >= 500 || resp.StatusCode == http.StatusUnauthorized { defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("http: failed to read response body (status=%v, err=%q)", resp.StatusCode, err) } return nil, &httpStatusError{ Response: resp, Body: body, } } return resp, err } // Proxy is a HTTP Handler which will send the queries directly to PDNS func (p *PowerDNS) Proxy(w http.ResponseWriter, r *http.Request) { // ctx is the Context for this handler. var ( ctx context.Context cancel context.CancelFunc ) ctx, cancel = context.WithCancel(context.Background()) defer cancel() // Thanks to context, we could do some req checks before submitting it to powerDNS client // Maybe later var response []byte // Get body if provided body, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, "Cannot read body", 400) return } if r.ContentLength == 0 { body = nil } uri := r.RequestURI stats := false if strings.HasPrefix(uri, "/stats/") { stats = true uri = strings.Replace(uri, "/stats/", "/", -1) } r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) _, head, err := p.sendQuery(ctx, uri, r.Method, body, &response) if err == nil { // copy the headers for name := range head { if stats && name == "Content-Length" { continue } w.Header().Set(name, head.Get(name)) } stringRet := string(response) if stats { stringRet = strings.Replace(stringRet, `href="/"`, `href="/stats/"`, -1) } fmt.Fprintf(w, "%s", stringRet) return } urlErr, ok := err.(*url.Error) if !ok { fmt.Fprintf(w, "Response is: %v", err) return } httpErr, ok := urlErr.Err.(*httpStatusError) if !ok { fmt.Fprintf(w, "Response is: %v", urlErr) return } fmt.Fprintf(w, "Response is: %v", httpErr) } // sendQuery returns the anwser of a request made to the API func (p *PowerDNS) sendQuery(ctx context.Context, endpoint, method string, body []byte, response interface{}) (int, http.Header, error) { var req *http.Request var err error u := url.URL{} u.Host = p.Hostname + ":" + p.Port u.Scheme = p.Scheme if !strings.HasPrefix(endpoint, "/") { u.Path = "/" } uri := u.String() + endpoint p.LogDebug(fmt.Sprintf("pdns.client %s", uri)) if body != nil { p.LogDebug(string(body)) } if body != nil { req, err = http.NewRequest(method, uri, bytes.NewReader(body)) } else { req, err = http.NewRequest(method, uri, nil) } if err != nil { return 500, nil, err } resp, err := p.Client.Do(req.WithContext(ctx)) if err != nil { return 500, nil, err } defer resp.Body.Close() p.LogDebug(fmt.Sprintf(" resp.Status=%s", resp.Status)) raw, _ := ioutil.ReadAll(resp.Body) if response == nil { return resp.StatusCode, resp.Header, nil } if err := json.Unmarshal(raw, response); err == nil { return resp.StatusCode, resp.Header, nil } if ar, ok := response.(*[]byte); ok { *ar = raw } p.LogDebug(string(raw)) return resp.StatusCode, resp.Header, nil } // parseBaseURL decompose an URL into parts func parseBaseURL(baseURL string) (string, string, string, string, error) { u, err := url.Parse(baseURL) if err != nil { return "", "", "", "", err } hp := strings.Split(u.Host, ":") hostname := hp[0] var port string if len(hp) > 1 { port = hp[1] } else { if u.Scheme == "https" { port = "443" } else { port = "8081" } } return u.Scheme, hostname, port, u.Path, nil }