summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6/async
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv6/async')
-rw-r--r--dhcpv6/async/client.go226
-rw-r--r--dhcpv6/async/client_test.go122
2 files changed, 348 insertions, 0 deletions
diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go
new file mode 100644
index 0000000..08c2cfb
--- /dev/null
+++ b/dhcpv6/async/client.go
@@ -0,0 +1,226 @@
+package async
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/fanliao/go-promise"
+ "github.com/insomniacslk/dhcp/dhcpv6"
+)
+
+// Client implements an asynchronous DHCPv6 client
+type Client struct {
+ ReadTimeout time.Duration
+ WriteTimeout time.Duration
+ LocalAddr net.Addr
+ RemoteAddr net.Addr
+ IgnoreErrors bool
+
+ connection *net.UDPConn
+ cancel context.CancelFunc
+ stopping *sync.WaitGroup
+ receiveQueue chan dhcpv6.DHCPv6
+ sendQueue chan dhcpv6.DHCPv6
+ packetsLock sync.Mutex
+ packets map[uint32]*promise.Promise
+ errors chan error
+}
+
+// NewClient creates an asynchronous client
+func NewClient() *Client {
+ return &Client{
+ ReadTimeout: dhcpv6.DefaultReadTimeout,
+ WriteTimeout: dhcpv6.DefaultWriteTimeout,
+ }
+}
+
+// OpenForInterface starts the client on the specified interface, replacing
+// client LocalAddr with a link-local address of the given interface and
+// standard DHCP port (546).
+func (c *Client) OpenForInterface(ifname string, bufferSize int) error {
+ addr, err := dhcpv6.GetLinkLocalAddr(ifname)
+ if err != nil {
+ return err
+ }
+ c.LocalAddr = &net.UDPAddr{IP: *addr, Port: dhcpv6.DefaultClientPort, Zone: ifname}
+ return c.Open(bufferSize)
+}
+
+// Open starts the client
+func (c *Client) Open(bufferSize int) error {
+ var (
+ addr *net.UDPAddr
+ ok bool
+ err error
+ )
+
+ if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok {
+ return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr)
+ }
+
+ // prepare the socket to listen on for replies
+ c.connection, err = net.ListenUDP("udp6", addr)
+ if err != nil {
+ return err
+ }
+ c.stopping = new(sync.WaitGroup)
+ c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize)
+ c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize)
+ c.packets = make(map[uint32]*promise.Promise)
+ c.packetsLock = sync.Mutex{}
+ c.errors = make(chan error)
+
+ var ctx context.Context
+ ctx, c.cancel = context.WithCancel(context.Background())
+ go c.receiverLoop(ctx)
+ go c.senderLoop(ctx)
+
+ return nil
+}
+
+// Close stops the client
+func (c *Client) Close() {
+ // Wait for sender and receiver loops
+ c.stopping.Add(2)
+ c.cancel()
+ c.stopping.Wait()
+
+ close(c.sendQueue)
+ close(c.receiveQueue)
+ close(c.errors)
+
+ c.connection.Close()
+}
+
+// Errors returns a channel where runtime errors are posted
+func (c *Client) Errors() <-chan error {
+ return c.errors
+}
+
+func (c *Client) addError(err error) {
+ if !c.IgnoreErrors {
+ c.errors <- err
+ }
+}
+
+func (c *Client) receiverLoop(ctx context.Context) {
+ defer func() { c.stopping.Done() }()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case packet := <-c.receiveQueue:
+ c.receive(packet)
+ }
+ }
+}
+
+func (c *Client) senderLoop(ctx context.Context) {
+ defer func() { c.stopping.Done() }()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case packet := <-c.sendQueue:
+ c.send(packet)
+ }
+ }
+}
+
+func (c *Client) send(packet dhcpv6.DHCPv6) {
+ transactionID, err := dhcpv6.GetTransactionID(packet)
+ if err != nil {
+ c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet))
+ return
+ }
+ c.packetsLock.Lock()
+ p := c.packets[transactionID]
+ c.packetsLock.Unlock()
+
+ raddr, err := c.remoteAddr()
+ if err != nil {
+ p.Reject(err)
+ return
+ }
+
+ c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
+ _, err = c.connection.WriteTo(packet.ToBytes(), raddr)
+ if err != nil {
+ p.Reject(err)
+ return
+ }
+
+ c.receiveQueue <- packet
+}
+
+func (c *Client) receive(_ dhcpv6.DHCPv6) {
+ var (
+ oobdata = []byte{}
+ received dhcpv6.DHCPv6
+ )
+
+ c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout))
+ for {
+ buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize)
+ n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata)
+ if err != nil {
+ if err, ok := err.(net.Error); !ok || !err.Timeout() {
+ c.addError(fmt.Errorf("Error receiving the message: %s", err))
+ }
+ return
+ }
+ received, err = dhcpv6.FromBytes(buffer[:n])
+ if err != nil {
+ // skip non-DHCP packets
+ continue
+ }
+ break
+ }
+
+ transactionID, err := dhcpv6.GetTransactionID(received)
+ if err != nil {
+ c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err))
+ return
+ }
+
+ c.packetsLock.Lock()
+ if p, ok := c.packets[transactionID]; ok {
+ delete(c.packets, transactionID)
+ p.Resolve(received)
+ }
+ c.packetsLock.Unlock()
+}
+
+func (c *Client) remoteAddr() (*net.UDPAddr, error) {
+ if c.RemoteAddr == nil {
+ return &net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort}, nil
+ }
+
+ if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok {
+ return addr, nil
+ }
+ return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr)
+}
+
+// Send inserts a message to the queue to be sent asynchronously.
+// Returns a future which resolves to response and error.
+func (c *Client) Send(message dhcpv6.DHCPv6, modifiers ...dhcpv6.Modifier) *promise.Future {
+ for _, mod := range modifiers {
+ message = mod(message)
+ }
+
+ transactionID, err := dhcpv6.GetTransactionID(message)
+ if err != nil {
+ return promise.Wrap(err)
+ }
+
+ p := promise.NewPromise()
+ c.packetsLock.Lock()
+ c.packets[transactionID] = p
+ c.packetsLock.Unlock()
+ c.sendQueue <- message
+ return p.Future
+}
diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go
new file mode 100644
index 0000000..8665589
--- /dev/null
+++ b/dhcpv6/async/client_test.go
@@ -0,0 +1,122 @@
+package async
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "github.com/insomniacslk/dhcp/dhcpv6"
+ "github.com/insomniacslk/dhcp/iana"
+)
+
+
+// solicit creates new solicit based on the mac address
+func solicit(input string) (dhcpv6.DHCPv6, error) {
+ mac, err := net.ParseMAC(input)
+ if err != nil {
+ return nil, err
+ }
+ duid := dhcpv6.Duid{
+ Type: dhcpv6.DUID_LLT,
+ HwType: iana.HwTypeEthernet,
+ Time: dhcpv6.GetTime(),
+ LinkLayerAddr: mac,
+ }
+ return dhcpv6.NewSolicitWithCID(duid)
+}
+
+// server creates a server which responds with predefined answers
+func serve(t *testing.T, addr *net.UDPAddr, responses ...dhcpv6.DHCPv6) {
+ conn, err := net.ListenUDP("udp6", addr)
+ require.NoError(t, err)
+ defer conn.Close()
+ oobdata := []byte{}
+ buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize)
+ for _, packet := range responses {
+ n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata)
+ require.NoError(t, err)
+ _, err = dhcpv6.FromBytes(buffer[:n])
+ require.NoError(t, err)
+ _, err = conn.WriteTo(packet.ToBytes(), src)
+ require.NoError(t, err)
+ }
+}
+
+func TestNewClient(t *testing.T) {
+ c := NewClient()
+ require.NotNil(t, c)
+ require.Equal(t, c.ReadTimeout, dhcpv6.DefaultReadTimeout)
+ require.Equal(t, c.ReadTimeout, dhcpv6.DefaultWriteTimeout)
+}
+
+func TestOpenInvalidAddrFailes(t *testing.T) {
+ c := NewClient()
+ err := c.Open(512)
+ require.Error(t, err)
+}
+
+// This test uses port 15438 so please make sure its not used before running
+func TestOpenClose(t *testing.T) {
+ c := NewClient()
+ addr, err := net.ResolveUDPAddr("udp6", ":15438")
+ require.NoError(t, err)
+ c.LocalAddr = addr
+ err = c.Open(512)
+ require.NoError(t, err)
+ defer c.Close()
+}
+
+// This test uses ports 15438 and 15439 so please make sure they are not used
+// before running
+func TestSendTimeout(t *testing.T) {
+ c := NewClient()
+ addr, err := net.ResolveUDPAddr("udp6", ":15438")
+ require.NoError(t, err)
+ remote, err := net.ResolveUDPAddr("udp6", ":15439")
+ require.NoError(t, err)
+ c.ReadTimeout = 50 * time.Millisecond
+ c.WriteTimeout = 50 * time.Millisecond
+ c.LocalAddr = addr
+ c.RemoteAddr = remote
+ err = c.Open(512)
+ require.NoError(t, err)
+ defer c.Close()
+ m, err := dhcpv6.NewMessage()
+ require.NoError(t, err)
+ _, err, timeout := c.Send(m).GetOrTimeout(200)
+ require.NoError(t, err)
+ require.True(t, timeout)
+}
+
+// This test uses ports 15438 and 15439 so please make sure they are not used
+// before running
+func TestSend(t *testing.T) {
+ s, err := solicit("c8:6c:2c:47:96:fd")
+ require.NoError(t, err)
+ require.NotNil(t, s)
+
+ a, err := dhcpv6.NewAdvertiseFromSolicit(s)
+ require.NoError(t, err)
+ require.NotNil(t, a)
+
+ c := NewClient()
+ addr, err := net.ResolveUDPAddr("udp6", ":15438")
+ require.NoError(t, err)
+ remote, err := net.ResolveUDPAddr("udp6", ":15439")
+ require.NoError(t, err)
+ c.LocalAddr = addr
+ c.RemoteAddr = remote
+
+ go serve(t, remote, a)
+
+ err = c.Open(16)
+ require.NoError(t, err)
+ defer c.Close()
+
+ f := c.Send(s)
+ response, err, timeout := f.GetOrTimeout(1000)
+ require.False(t, timeout)
+ require.NoError(t, err)
+ require.Equal(t, a, response)
+}