diff --git a/dhcp.go b/dhcp.go new file mode 100644 index 0000000..195f086 --- /dev/null +++ b/dhcp.go @@ -0,0 +1,34 @@ +package main + +import ( + "errors" +) + +func (s *OpenVpnMgt) isFree(ip string) bool { + for _, remote := range s.clients { + for _, c := range remote { + if c.PrivIP == ip { + return false + } + } + } + return true +} + +// internal DHCP +func (s *OpenVpnMgt) getIP(c *vpnSession) (string, error) { + s.m.Lock() + defer s.m.Unlock() + + ipmax := nextIP(s.ldap[c.Profile].ipMax).String() + + sip := s.ldap[c.Profile].ipMin.String() + for ip := s.ldap[c.Profile].ipMin; sip != ipmax; ip = nextIP(ip) { + sip = ip.String() + if s.isFree(sip) { + return sip, nil + } + } + + return "", errors.New("no more IP") +} diff --git a/utils.go b/utils.go index d09882c..8171c8b 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package main import ( + "math/big" + "net" "sort" "github.com/pyke369/golang-support/uconfig" @@ -32,3 +34,14 @@ func parseConfigArray(config *uconfig.UConfig, configpath string) []string { } return result } + +func nextIP(ip net.IP) net.IP { + // Convert to big.Int and increment + ipb := big.NewInt(0).SetBytes([]byte(ip)) + ipb.Add(ipb, big.NewInt(1)) + + // Add leading zeros + b := ipb.Bytes() + b = append(make([]byte, len(ip)-len(b)), b...) + return net.IP(b) +} diff --git a/vpnserver.go b/vpnserver.go index c463477..1ae2c2e 100644 --- a/vpnserver.go +++ b/vpnserver.go @@ -126,14 +126,6 @@ func (s *OpenVpnMgt) Version() (error, map[string][]string) { return nil, ret } -// internal DHCP -func (s *OpenVpnMgt) getIP(c *vpnSession) (string, error) { - // TODO implement - ip := s.ldap[c.Profile].ipMin - - return ip.String(), nil -} - // called after a client is confirmed connected and authenticated func (s *OpenVpnMgt) ClientValidated(line, remote string) { err, c := s.getClient(line, remote) @@ -153,7 +145,6 @@ func (s *OpenVpnMgt) ClientValidated(line, remote string) { // called after a client is disconnected, including for auth issues func (s *OpenVpnMgt) ClientDisconnect(line, remote string) { - //TODO free the IP err, c := s.getClient(line, remote) if err != nil { log.Println(err) @@ -218,8 +209,6 @@ func (s *OpenVpnMgt) handleConn(conn net.Conn) { defer delete(s.buf, remote) defer delete(s.clients, remote) - // TODO : free all IPs if disconnected - // we store the buffer pointer in the struct, to be accessed from other methods s.buf[remote] = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) s.clients[remote] = make(map[int]*vpnSession) @@ -311,9 +300,5 @@ func (s *OpenVpnMgt) handleConn(conn net.Conn) { default: response = append(response, line) } - // TODO remove this - if false && strings.Index(line, "password") == -1 { - log.Print(line) - } } }