pdns-auth-proxy/client/main.go

720 lines
20 KiB
Go

package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
"time"
)
const (
// defaultAPI is the base URL of the DNS API. init uses it unless the
// DNS_API environment variable is set to a non-empty value.
defaultAPI = "https://localhost:8443"
)
type (
// json definitions for the web API

// jsonArray is the list of commands that main serialises to JSON and
// sends to the API in a single request.
jsonArray []*jsonCommand
// jsonCommand is one JSON-RPC call. SetArgs fills Method with the command
// name and JSONRPC with "2.0".
jsonCommand struct {
Method string `json:"method"`
ID int `json:"id"`
JSONRPC string `json:"jsonrpc"`
Params jsonCommandParams `json:"params"`
// args holds the raw positional arguments, command name included at
// index 0. It is unexported and therefore never serialised.
args []string
}
// jsonCommandParams carries the parameters of a call. Name and Value are
// taken from the first two positional arguments by CheckArgs; the check
// functions may overwrite them. The other serialised fields come from the
// command-line flags, the batch file, or SetNonce.
jsonCommandParams struct {
Name string `json:"name"`
Value string `json:"value"`
TTL int `json:"ttl"`
ForceReverse bool `json:"reverse"`
Append bool `json:"append"`
Comment string `json:"comment"`
DryRun bool `json:"dry-run"`
IgnoreError bool `json:"ignore-error"`
Nonce string `json:"nonce"`
// Debug is client-side only (never serialised): when set, main prints
// the JSON payload before signing it.
Debug bool `json:"-"`
}
// apiResponse is the answer decoded by sendQuery. main prints
// Error.Message, Comment and Result on stderr and Changes on stdout.
apiResponse struct {
ID int `json:"id"`
Result []struct {
Changes string `json:"changes"`
Comment string `json:"comment"`
Result string `json:"result"`
} `json:"result"`
Error struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
// command describes one command-line action:
// - args: names of the positional arguments; CheckArgs treats those
// starting with '<' as mandatory and the others as optional,
// - descr: help text shown by Usage,
// - options: names of the flags listed in the usage of this action, in
// addition to generalFlags,
// - check: optional validation callback run by CheckArgs; it may modify
// the jsonCommand it receives.
command struct {
args []string
descr string
options []string
check func(*jsonCommand) bool
}
)
// global vars, populated by init
var (
// dnsAPI is the base URL of the API; SetNonce and sendQuery append
// "/nonce" and "/jsonrpc" to it.
dnsAPI string
// httpClient is the client shared by every request; init sets its timeout.
httpClient http.Client
// commands maps an action name to its description and validation rules.
commands map[string]command
// generalFlags lists the flag names shown in the usage of every action.
generalFlags []string
)
// init populates the generalFlags, commands and dnsAPI variables, disables
// TLS certificate verification when the API is served from localhost, sets the
// HTTP client timeout and installs Usage as the usage function of the flag
// package.
func init() {
	generalFlags = []string{"dry-run", "comment"}
	commands = map[string]command{
		"a": {
			[]string{"<record>", "<ipv4>"},
			"Create a A record, points to an IPv4. If the IP has no reverse, create it with the record value (won't apply for wildcard)",
			[]string{"append", "reverse", "ttl"}, cmdCheckIPv4},
		"aaaa": {
			[]string{"<record>", "<ipv6>"},
			"Create a AAAA record, points to an IPv6. If the IP has no reverse, create it with the record value (won't apply for wildcard)",
			[]string{"append", "reverse", "ttl"}, cmdCheckIPv6},
		"cname": {
			[]string{"<record>", "<destination>"},
			"Create a CNAME record, points to another name. If that name is managed by the DNS server, it must exist",
			[]string{"ttl"}, nil},
		"dname": {
			[]string{"<record>", "<destination>"},
			"Create a DNAME record, points to another name. If that name is managed by the DNS server, it must exist",
			[]string{"ttl"}, nil},
		"caa": {
			[]string{"<domain>", "<flag>", "<tag>", "<value>"},
			"Create a CAA record. If value contains spaces, it must be quoted",
			[]string{"append", "ttl"}, cmdCheckCAA},
		"srv": {
			[]string{"<_service._proto.name>", "<priority>", "<weight>", "<port>", "<target>"},
			"Create a SRV record. https://en.wikipedia.org/wiki/SRV_record",
			[]string{"append", "ttl"}, cmdCheckSRV},
		"txt": {
			[]string{"<record>", "<\"text\">"},
			"Create a TXT record. No need to escape quotes, it will be done automatically on the server",
			[]string{"append", "ttl"}, nil},
		"mx": {
			[]string{"<record>", "<priority>", "<mail-server>"},
			"Create a MX record. mail-server must exist",
			[]string{"append", "ttl"}, cmdCheckMX},
		"ns": {
			[]string{"<record>", "<dns-server>"},
			"Create a NS record. dns-server must exist and must be an external server",
			[]string{"append", "ttl"}, nil},
		"ptr": {
			[]string{"<ip | something.arpa>", "<name>"},
			"Add a PTR record. If the first argument is an IP, will convert it to an .arpa name",
			[]string{"append", "ttl"}, cmdCheckPTR},
		"delete": {
			[]string{"<record>", "[value]"},
			"Remove the record. If value is not provided, remove all records of all types sharing the name. If value is specified, only remove that particuliar value. If a deleted record corresponds to a matching reverse, will remove the reverse too",
			[]string{}, nil},
		"ttl": {
			[]string{"<record>", "[value]", "<ttl>"},
			"Change the TTL of a record. If value is not provided, change the TTL of all types sharing the name. If value is specified, only change that particuliar TTL",
			[]string{}, cmdCheckTTL},
		"newzone": {
			[]string{"<zone>"},
			"Add a new zone. The zone will be private for a private IPs reverse zone. NS and SOA options are automatically changed",
			[]string{}, nil},
		"search": {
			[]string{"<query>"},
			"Search the pdns database",
			[]string{}, nil},
		"dump": {
			[]string{"<zone>"},
			"Display the zone",
			[]string{}, nil},
		"list": {
			[]string{"[regexp]"},
			"List zones, filter on regexp if provided",
			[]string{}, nil},
		"batch": {
			[]string{"<file>"},
			"Batch mode: file is a jsonRPC 2.0 complient json file with all commands you want to execute. Example : \n [\n { \"jsonrpc\": \"2.0\", \"id\": 0, \"method\": \"list\" },\n { \"jsonrpc\": \"2.0\", \"id\": 1, \"method\": \"newzone\", \"params\": { \"name\": \"example.com\", \"ignore-error\": true } },\n { \"jsonrpc\": \"2.0\", \"id\": 2, \"method\": \"a\", \"params\": {\n \"comment\": \"it's the fault\", \"name\": \"toto.example.com\",\n \"value\": \"192.0.2.1\" } }\n ]\n By default, an error will stop the batch. Use the boolean ignore-error to change the behaviour for a particular line. The comment and ttl switches are applied only if no explicit value is provided in a line",
			[]string{"ttl"}, nil},
	}

	// The default API URL can be overridden through the environment (for
	// dev/debug).
	dnsAPI = defaultAPI
	if api := os.Getenv("DNS_API"); api != "" {
		dnsAPI = api
	}

	// Certificate verification is skipped for localhost only. The host is
	// compared exactly: a plain prefix test would also match names such as
	// "localhost.example.com" and silently disable TLS checks for them.
	if isLocalhostAPI(dnsAPI) {
		http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	}

	// set the http timeout
	httpClient.Timeout = 60 * time.Second
	flag.Usage = Usage
}

// isLocalhostAPI reports whether api is an https URL whose host is exactly
// "localhost", with or without a port, user info, path, query or fragment.
func isLocalhostAPI(api string) bool {
	const scheme = "https://"
	if !strings.HasPrefix(api, scheme) {
		return false
	}
	authority := api[len(scheme):]
	// the authority ends at the first path, query or fragment delimiter
	if i := strings.IndexAny(authority, "/?#"); i >= 0 {
		authority = authority[:i]
	}
	// drop any "user:password@" part, it is not the host
	if i := strings.LastIndex(authority, "@"); i >= 0 {
		authority = authority[i+1:]
	}
	// drop the port
	if i := strings.Index(authority, ":"); i >= 0 {
		authority = authority[:i]
	}
	return authority == "localhost"
}
// cmdCheckIPv4 validates the "a" action: the value (j.args[2]) must parse as
// an IP address that has an IPv4 form.
func cmdCheckIPv4(j *jsonCommand) bool {
	parsed := net.ParseIP(j.args[2])
	if parsed == nil {
		return false
	}
	return parsed.To4() != nil
}
// cmdCheckIPv6 validates the "aaaa" action: the value (j.args[2]) must parse
// as an IP address that has no IPv4 form.
func cmdCheckIPv6(j *jsonCommand) bool {
	parsed := net.ParseIP(j.args[2])
	if parsed == nil {
		return false
	}
	return parsed.To4() == nil
}
// cmdCheckPTR validates the "ptr" action.
// The first argument must be an IP or an .arpa name; either way the record
// name is rewritten to the .arpa form of that IP. The target name is then
// resolved: a target that is a CNAME is rejected, while a target that does not
// resolve, or does not resolve to the IP, only triggers a warning.
func cmdCheckPTR(j *jsonCommand) bool {
	j.args[1] = strings.Trim(j.args[1], ".")
	j.args[2] = strings.Trim(j.args[2], ".")
	source, target := j.args[1], j.args[2]

	// an .arpa name is first turned back into the textual IP it encodes
	literal := source
	if strings.HasSuffix(source, ".arpa") {
		literal = ptrToIP(source)
	}
	ip := net.ParseIP(literal)
	if ip == nil {
		fmt.Fprintf(os.Stderr, "Error: %s is not a valid IP or reverse\n\n", source)
		return false
	}
	j.Params.Name = iPtoReverse(ip)
	wanted := ip.String()

	addrs, err := net.LookupHost(target)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: %s must exist\n\n", target)
		return true
	}

	canonical, err := net.LookupCNAME(target)
	if err == nil && strings.Trim(canonical, ".") != target {
		fmt.Fprintf(os.Stderr, "Error: %s cannot be a CNAME\n\n", target)
		return false
	}

	for _, addr := range addrs {
		if addr == wanted {
			return true
		}
	}
	fmt.Fprintf(os.Stderr, "Warning: %s must point to %s\n\n", target, wanted)
	return true
}
// cmdCheckCAA validates the "caa" action.
// The flag (j.args[2]) must be an integer between 0 and 255. The value
// (j.args[4]) is stripped of its surrounding double quotes, has its remaining
// double quotes escaped and is quoted again; the record value then becomes
// "<flag> <tag> <quoted value>".
func cmdCheckCAA(j *jsonCommand) bool {
	const quote = "\""

	flagValue, err := strconv.Atoi(j.args[2])
	if err != nil || flagValue < 0 || flagValue > 255 {
		return false
	}

	value := strings.Trim(j.args[4], quote)
	value = strings.Replace(value, quote, "\\\"", -1)
	j.args[4] = quote + value + quote

	j.Params.Value = fmt.Sprintf("%d %s %s", flagValue, j.args[3], j.args[4])
	return true
}
// cmdCheckSRV validates the "srv" action.
// The record name (j.args[1]) must look like _service._proto.name, and the
// priority, weight and port (j.args[2] to j.args[4]) must be integers that fit
// the 16-bit fields of a SRV record; the port must also be at least 1. On
// success the record value becomes "<priority> <weight> <port> <target>".
func cmdCheckSRV(j *jsonCommand) bool {
	// priority, weight and port are unsigned 16-bit values on the wire
	const maxUint16 = 65535

	validSRV := regexp.MustCompile("^_[a-z0-9]+\\._(tcp|udp|tls)[^ ]+$")
	if !validSRV.MatchString(j.args[1]) {
		return false
	}

	prio, err := strconv.Atoi(j.args[2])
	if err != nil || prio < 0 || prio > maxUint16 {
		return false
	}
	weight, err := strconv.Atoi(j.args[3])
	if err != nil || weight < 0 || weight > maxUint16 {
		return false
	}
	port, err := strconv.Atoi(j.args[4])
	if err != nil || port < 1 || port > maxUint16 {
		return false
	}

	j.Params.Value = fmt.Sprintf("%d %d %d %s", prio, weight, port, j.args[5])
	return true
}
// cmdCheckMX validates the "mx" action.
// The priority (j.args[2]) must be an integer that fits the 16-bit preference
// field of a MX record. On success the record value becomes
// "<priority> <mail-server>".
func cmdCheckMX(j *jsonCommand) bool {
	// the MX preference is an unsigned 16-bit value on the wire
	const maxUint16 = 65535

	prio, err := strconv.Atoi(j.args[2])
	if err != nil || prio < 0 || prio > maxUint16 {
		return false
	}
	j.Params.Value = fmt.Sprintf("%d %s", prio, j.args[3])
	return true
}
// cmdCheckTTL validates the "ttl" action.
// The arguments are "<record>", "[value]", "<ttl>": the TTL is always the last
// one and must be a non-negative integer. When the optional value is absent,
// the record value is reset to the empty string.
func cmdCheckTTL(j *jsonCommand) bool {
	count := len(j.args)

	ttl, err := strconv.Atoi(j.args[count-1])
	j.Params.TTL = ttl
	if err != nil || ttl < 0 {
		return false
	}

	// without the optional argument, the value slot currently holds the TTL
	if count < 4 {
		j.Params.Value = ""
	}
	return true
}
// Usage returns the help text of this command: the synopsis, the wrapped
// description and one line per flag (general flags first, then the flags
// specific to the command). A flag registered under another name with the same
// usage text is displayed as the short form of the option.
func (c *command) Usage(name string) string {
	opts := make([]string, 0, len(generalFlags)+len(c.options))
	opts = append(opts, generalFlags...)
	opts = append(opts, c.options...)

	// index the long options by usage text, then look for their twin
	byUsage := make(map[string]string, len(opts))
	for _, opt := range opts {
		byUsage[flag.Lookup(opt).Usage] = opt
	}
	shortNames := map[string]string{}
	flag.VisitAll(func(f *flag.Flag) {
		long, found := byUsage[f.Usage]
		if found && long != f.Name {
			shortNames[long] = f.Name
		}
	})

	var out strings.Builder
	fmt.Fprintf(&out, " %s [options] %s %s : \n %s\n Options:\n",
		os.Args[0], name, c.formatArgs(), c.formatDescr())
	for _, opt := range opts {
		f := flag.Lookup(opt)

		short := ""
		if alias, ok := shortNames[opt]; ok {
			short = "-" + alias + ","
		}

		// only non-boolean flags take a value and advertise a default
		arg, defValue := "", ""
		isBool := f.DefValue == "false" || f.DefValue == "true"
		if !isBool {
			arg = "<value>"
			defValue = ", default is " + f.DefValue
		}
		if f.DefValue == "" {
			defValue += "\"\""
		}

		fmt.Fprintf(&out, " %-3v --%-7v %-7v : %s%s\n",
			short, f.Name, arg, f.Usage, defValue)
	}
	return out.String()
}
// CheckArgs copies the positional arguments of the command line into j and
// reports whether they are acceptable: their count must fit the command
// definition and the command's check function, if any, must accept them.
// The check function may modify j.
func (c *command) CheckArgs(j *jsonCommand) bool {
	args := flag.Args()

	// the command name itself, plus every mandatory ("<...>") argument
	required := 1
	for _, spec := range c.args {
		if spec[0] == '<' {
			required++
		}
	}
	if n := len(args); n > len(c.args)+1 || n < required {
		return false
	}

	// args[0] is the command name; name and value follow when present
	if len(args) >= 2 {
		j.Params.Name = args[1]
	}
	if len(args) >= 3 {
		j.Params.Value = args[2]
	}

	if c.check == nil {
		return true
	}
	return c.check(j)
}
// SetDryRun turns the dry-run flag on for every command of the array.
func (ja jsonArray) SetDryRun() {
	for _, entry := range ja {
		entry.Params.DryRun = true
	}
}
// SetNonce fetches a nonce from the "/nonce" endpoint of the API and stores it
// in the first command of the array. According to the original author, the
// nonce is valid for 10 min and exists to avoid replay attacks.
//
// It returns an error when the array has no usable first command, when the
// request or the read of its body fails, or when the body contains anything
// but the characters A-Z, a-z, 0-9, '+' and '/'.
func (ja jsonArray) SetNonce() error {
	// nothing to attach the nonce to (e.g. an empty batch file): fail
	// before contacting the server instead of panicking afterwards
	if len(ja) == 0 || ja[0] == nil {
		return errors.New("Cannot get Nonce : no command to send")
	}

	valid := regexp.MustCompile(`^[A-Za-z0-9+/]*$`)
	req, err := http.NewRequest("GET", fmt.Sprintf("%s/nonce", dnsAPI), nil)
	if err != nil {
		return err
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	nonce := string(body)
	if !valid.MatchString(nonce) {
		return errors.New("Cannot get Nonce : Invalid response")
	}
	ja[0].Params.Nonce = nonce
	return nil
}
// SetArgs turns the command line into the list of commands to send.
//
// args[0] selects the command; j receives the method, the raw arguments and
// the JSON-RPC envelope, and is validated with CheckArgs (the usage is printed
// and the program exits when validation fails). For every command but "batch",
// j is appended to ja. For "batch", args[1] is read as a JSON array of
// commands which replaces the content of ja; each entry is renumbered and
// inherits the TTL, comment and dry-run values of the command line when it
// does not set them itself.
//
// It returns an error when args is empty, when the command is unknown, when
// the batch file cannot be read or decoded, or when it contains a null entry.
func (ja *jsonArray) SetArgs(j *jsonCommand, args []string) error {
	if len(args) < 1 {
		return errors.New("Not enough args")
	}
	cmd, ok := commands[args[0]]
	if !ok {
		return fmt.Errorf("Unknown command %s", args[0])
	}
	j.Method = args[0]
	j.args = args
	j.JSONRPC = "2.0"
	j.ID = 1

	// check the arguments
	if !cmd.CheckArgs(j) {
		Usage()
	}

	// evacuate the simple case first
	if args[0] != "batch" {
		*ja = append(*ja, j)
		return nil
	}

	batch, err := ioutil.ReadFile(args[1])
	if err != nil {
		return err
	}
	if err := json.Unmarshal(batch, ja); err != nil {
		return err
	}
	for i, entry := range *ja {
		// a JSON null decodes to a nil pointer: report it instead of
		// dereferencing it
		if entry == nil {
			return fmt.Errorf("Batch entry %d is null", i)
		}
		entry.JSONRPC = j.JSONRPC
		entry.ID = i
		// command-line values are only defaults for the batch entries
		if entry.Params.TTL == 0 {
			entry.Params.TTL = j.Params.TTL
		}
		if entry.Params.Comment == "" {
			entry.Params.Comment = j.Params.Comment
		}
		if !entry.Params.DryRun {
			entry.Params.DryRun = j.Params.DryRun
		}
	}
	return nil
}
// Usage prints the help on the output of the flag package and exits with
// status 1. When the first positional argument is a known command, only that
// command is described; otherwise every command is, in alphabetical order.
func Usage() {
	out := flag.CommandLine.Output()
	fmt.Fprintln(out, "Usage: ")

	// flag.Arg(0) is empty when there is no argument, and no command has
	// an empty name
	requested := flag.Arg(0)
	if cmd, known := commands[requested]; known {
		fmt.Fprintf(out, "%s\n\n", cmd.Usage(requested))
		os.Exit(1)
	}

	names := make([]string, 0, len(commands))
	for name := range commands {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		cmd := commands[name]
		fmt.Fprintf(out, "%s\n\n", cmd.Usage(name))
	}
	os.Exit(1)
}
// GetYubikey runs "gpg --card-status --with-colons" and returns the second
// colon-separated field of the first output line starting with "fpr:".
// It returns an error when gpg fails or when no such line is found.
func GetYubikey() (string, error) {
	raw, err := exec.Command("gpg", "--card-status", "--with-colons").CombinedOutput()
	if err != nil {
		return "", err
	}
	for _, line := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(line, "fpr:") {
			continue
		}
		fields := strings.Split(line, ":")
		return fields[1], nil
	}
	return "", errors.New("Yubikey issue (is it plugged in?)")
}
// Sign runs "gpg --clearsign -u <gpgKey>" on payload and returns the combined
// standard output and standard error of gpg.
// If gpgKey is empty, gpg is not run and payload is returned unchanged.
// If gpg fails, payload is returned unchanged together with the error.
func Sign(payload []byte, gpgKey string) ([]byte, error) {
	if gpgKey == "" {
		return payload, nil
	}
	cmd := exec.Command("gpg", "--clearsign", "-u", gpgKey)
	// Let os/exec feed the payload once the process runs. Writing to a
	// stdin pipe before the process is started blocks forever as soon as
	// the payload is larger than the pipe buffer.
	cmd.Stdin = bytes.NewReader(payload)
	out, err := cmd.CombinedOutput()
	if err != nil {
		return payload, err
	}
	return out, nil
}
// sendQuery posts payload to the "/jsonrpc" endpoint of the API, with a
// text/plain content type, and decodes the answer.
// The body is first decoded as an array of responses; when that yields
// nothing, it is decoded as a single response, which is returned as a
// one-element slice. It returns an error when the request or the read of the
// body fails, or when the body matches neither form.
func sendQuery(payload []byte) ([]*apiResponse, error) {
	req, err := http.NewRequest("POST", fmt.Sprintf("%s/jsonrpc", dnsAPI), bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "text/plain")
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// batch answer: an array of responses
	var many []*apiResponse
	if err := json.Unmarshal(body, &many); err == nil && len(many) > 0 {
		return many, nil
	}

	// single answer: one response object
	single := &apiResponse{}
	if err := json.Unmarshal(body, single); err != nil {
		return nil, err
	}
	return []*apiResponse{single}, nil
}
// main parses the command line, builds the list of commands, attaches a
// nonce, signs the JSON payload with gpg when a key is available (dry-run is
// forced when no Yubikey is found) and sends it to the API. Error messages
// and comments of the answer go to stderr, prefixed with "; "; changes go to
// stdout.
func main() {
	var (
		cli     jsonCommand
		request jsonArray
	)

	// parse the command line; each option has a long and a short spelling
	// bound to the same field
	flag.BoolVar(&cli.Params.ForceReverse, "reverse", false, "Force the creation of a reverse")
	flag.BoolVar(&cli.Params.ForceReverse, "r", false, "Force the creation of a reverse")
	flag.BoolVar(&cli.Params.DryRun, "dry-run", false, "Explain what the command would do, but do nothing")
	flag.BoolVar(&cli.Params.DryRun, "n", false, "Explain what the command would do, but do nothing")
	flag.StringVar(&cli.Params.Comment, "comment", "", "Add a comment to the operation")
	flag.StringVar(&cli.Params.Comment, "c", "", "Add a comment to the operation")
	flag.IntVar(&cli.Params.TTL, "ttl", 172800, "Specify the TTL")
	flag.IntVar(&cli.Params.TTL, "t", 172800, "Specify the TTL")
	flag.BoolVar(&cli.Params.Append, "append", false, "Append the value, don't replace the whole record")
	flag.BoolVar(&cli.Params.Append, "a", false, "Append the value, don't replace the whole record")
	flag.BoolVar(&cli.Params.Debug, "debug", false, "Add debug info")
	flag.BoolVar(&cli.Params.Debug, "d", false, "Add debug info")
	flag.Parse()

	// copy and check the arguments
	if err := request.SetArgs(&cli, flag.Args()); err != nil {
		fmt.Fprintln(os.Stderr, err)
		Usage()
	}

	// from here on, the command is valid and has the right number of
	// arguments: add the nonce
	if err := request.SetNonce(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(5)
	}

	// the key comes from the DM_GPG_KEY env var (useful with several
	// Yubikeys) or, failing that, from the plugged-in Yubikey
	gpgKey := os.Getenv("DM_GPG_KEY")
	if gpgKey == "" {
		var err error
		gpgKey, err = GetYubikey()
		if err != nil {
			fmt.Fprintln(os.Stderr, "Warning, no Yubikey; forcing dry-run mode")
			request.SetDryRun()
		}
	}

	// transform into json
	payload, err := json.MarshalIndent(request, " ", " ")
	if err != nil {
		fmt.Fprintln(os.Stderr, "panic")
		os.Exit(6)
	}
	if cli.Params.Debug {
		fmt.Println(string(payload))
	}

	// sign
	signed, err := Sign(payload, gpgKey)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(7)
	}

	// send to the server
	responses, err := sendQuery(signed)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(9)
	}

	// commented renders a possibly multi-line text as "; " prefixed lines
	// followed by an empty line
	commented := func(text string) string {
		return "; " + strings.Replace(text, "\n", "\n; ", -1) + "\n"
	}
	for _, response := range responses {
		if response.Error.Message != "" {
			fmt.Fprintln(os.Stderr, commented(response.Error.Message))
			continue
		}
		for _, result := range response.Result {
			if strings.Trim(result.Comment, " \n") != "" {
				fmt.Fprintln(os.Stderr, commented(result.Comment))
				if result.Result != "" {
					fmt.Fprintln(os.Stderr, "; "+result.Result)
				}
			}
			if result.Changes != "" {
				fmt.Println(result.Changes)
			}
		}
	}
}
// formatDescr returns the description with a line break inserted before any
// word that would push the current line past 70 characters. A word that
// already contains a line break restarts the count.
func (c *command) formatDescr() string {
	const width = 70

	var out strings.Builder
	column := 0
	for _, word := range strings.Split(c.descr, " ") {
		switch {
		case strings.Contains(word, "\n"):
			// the word brings its own line break
			column = 0
		case column+len(word) > width:
			out.WriteString("\n ")
			column = len(word) + 1
		default:
			column += len(word) + 1
		}
		out.WriteString(word)
		out.WriteString(" ")
	}
	return strings.Trim(out.String(), " ")
}
// formatArgs returns the argument names of the command separated by single
// spaces, as displayed by Usage.
func (c *command) formatArgs() string {
	const separator = " "
	joined := strings.Join(c.args, separator)
	return joined
}
// iPtoReverse returns the reverse-lookup name of ip, with its trailing dot:
// "d.c.b.a.in-addr.arpa." for an IPv4 address a.b.c.d (in 4-byte or 16-byte
// form), and the nibble-by-nibble "....ip6.arpa." name otherwise.
func iPtoReverse(ip net.IP) (arpa string) {
	// To4 yields the 4-byte form whichever way the IPv4 address is stored
	if v4 := ip.To4(); v4 != nil {
		return strconv.Itoa(int(v4[3])) + "." +
			strconv.Itoa(int(v4[2])) + "." +
			strconv.Itoa(int(v4[1])) + "." +
			strconv.Itoa(int(v4[0])) + ".in-addr.arpa."
	}

	// Must be IPv6: one label per nibble, last nibble first
	const hexDigit = "0123456789abcdef"
	buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
	for i := len(ip) - 1; i >= 0; i-- {
		b := ip[i]
		buf = append(buf, hexDigit[b&0xF], '.', hexDigit[b>>4], '.')
	}
	// buf already ends with a dot
	buf = append(buf, "ip6.arpa."...)
	return string(buf)
}
// uitoa converts an unsigned integer to its decimal string.
func uitoa(val uint) string {
	return strconv.FormatUint(uint64(val), 10)
}
// ptrToIP converts an .arpa name to the textual form of the IP it encodes.
// The labels are read in reverse order, skipping "arpa", "in-addr", "ip6" and
// empty labels. The others are joined with dots, or, once "ip6" has been seen,
// concatenated with a colon after every fourth label.
func ptrToIP(s string) string {
	var ip strings.Builder
	isV6 := false
	seen := 0
	for _, label := range strings.Split(reverseParts(s), ".") {
		if label == "ip6" {
			isV6 = true
			continue
		}
		if label == "" || label == "in-addr" || label == "arpa" {
			continue
		}
		seen++
		ip.WriteString(label)
		// no separator after the 4th byte of an IPv4...
		if !isV6 && seen != 4 {
			ip.WriteString(".")
		}
		// ...nor after the 32nd nibble of an IPv6
		if isV6 && seen%4 == 0 && seen != 32 {
			ip.WriteString(":")
		}
	}
	return ip.String()
}
// reverseParts returns s with its dot-separated parts in reverse order.
func reverseParts(s string) string {
	labels := strings.Split(s, ".")
	flipped := make([]string, len(labels))
	for i, label := range labels {
		flipped[len(labels)-1-i] = label
	}
	return strings.Join(flipped, ".")
}