summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMikoĊ‚aj Walczak <mikiwalczak+github@gmail.com>2018-07-12 10:51:38 +0100
committerinsomniac <insomniacslk@users.noreply.github.com>2018-07-12 10:51:38 +0100
commit8e3bcdab237624421034ccc4eb16f260d4338aec (patch)
tree0cb93b736c59506f68df67ac1150e80047dc202c
parent34154e71da6f5b4527809dc0babdefcbd262281c (diff)
Asynchronous client for DHCPv6 (#80)
-rw-r--r--dhcpv6/async.go235
-rw-r--r--dhcpv6/async_test.go54
-rw-r--r--dhcpv6/dhcpv6relay.go2
-rw-r--r--dhcpv6/future.go111
-rw-r--r--dhcpv6/future_test.go170
-rw-r--r--dhcpv6/utils.go17
-rw-r--r--dhcpv6/utils_test.go18
7 files changed, 606 insertions, 1 deletions
diff --git a/dhcpv6/async.go b/dhcpv6/async.go
new file mode 100644
index 0000000..e9930bf
--- /dev/null
+++ b/dhcpv6/async.go
@@ -0,0 +1,235 @@
+package dhcpv6
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+)
+
+// AsyncClient implements an asynchronous DHCPv6 client
+type AsyncClient struct {
+ ReadTimeout time.Duration
+ WriteTimeout time.Duration
+ LocalAddr net.Addr
+ RemoteAddr net.Addr
+ IgnoreErrors bool
+
+ connection *net.UDPConn
+ cancel context.CancelFunc
+ stopping *sync.WaitGroup
+ receiveQueue chan DHCPv6
+ sendQueue chan DHCPv6
+ packetsLock sync.Mutex
+ packets map[uint32](chan Response)
+ errors chan error
+}
+
+// NewAsyncClient creates an asynchronous client
+func NewAsyncClient() *AsyncClient {
+ return &AsyncClient{
+ ReadTimeout: DefaultReadTimeout,
+ WriteTimeout: DefaultWriteTimeout,
+ }
+}
+
+// OpenForInterface starts the client on the specified interface, replacing
+// client LocalAddr with a link-local address of the given interface and
+// standard DHCP port (546).
+func (c *AsyncClient) OpenForInterface(ifname string, bufferSize int) error {
+ addr, err := GetLinkLocalAddr(ifname)
+ if err != nil {
+ return err
+ }
+ c.LocalAddr = &net.UDPAddr{IP: *addr, Port: DefaultClientPort, Zone: ifname}
+ return c.Open(bufferSize)
+}
+
+// Open starts the client
+func (c *AsyncClient) Open(bufferSize int) error {
+ var (
+ addr *net.UDPAddr
+ ok bool
+ err error
+ )
+
+ if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok {
+ return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr)
+ }
+
+ // prepare the socket to listen on for replies
+ c.connection, err = net.ListenUDP("udp6", addr)
+ if err != nil {
+ return err
+ }
+ c.stopping = new(sync.WaitGroup)
+ c.sendQueue = make(chan DHCPv6, bufferSize)
+ c.receiveQueue = make(chan DHCPv6, bufferSize)
+ c.packets = make(map[uint32](chan Response))
+ c.packetsLock = sync.Mutex{}
+ c.errors = make(chan error)
+
+ var ctx context.Context
+ ctx, c.cancel = context.WithCancel(context.Background())
+ go c.receiverLoop(ctx)
+ go c.senderLoop(ctx)
+
+ return nil
+}
+
+// Close stops the client
+func (c *AsyncClient) Close() {
+ // Wait for sender and receiver loops
+ c.stopping.Add(2)
+ c.cancel()
+ c.stopping.Wait()
+
+ close(c.sendQueue)
+ close(c.receiveQueue)
+ close(c.errors)
+
+ c.connection.Close()
+}
+
+// Errors returns a channel where runtime errors are posted
+func (c *AsyncClient) Errors() <-chan error {
+ return c.errors
+}
+
+func (c *AsyncClient) addError(err error) {
+ if !c.IgnoreErrors {
+ c.errors <- err
+ }
+}
+
+func (c *AsyncClient) receiverLoop(ctx context.Context) {
+ defer func() { c.stopping.Done() }()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case packet := <-c.receiveQueue:
+ c.receive(packet)
+ }
+ }
+}
+
+func (c *AsyncClient) senderLoop(ctx context.Context) {
+ defer func() { c.stopping.Done() }()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case packet := <-c.sendQueue:
+ c.send(packet)
+ }
+ }
+}
+
+func (c *AsyncClient) send(packet DHCPv6) {
+ transactionID, err := GetTransactionID(packet)
+ if err != nil {
+ c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet))
+ return
+ }
+ c.packetsLock.Lock()
+ f := c.packets[transactionID]
+ c.packetsLock.Unlock()
+
+ raddr, err := c.remoteAddr()
+ if err != nil {
+ f <- NewResponse(nil, err)
+ return
+ }
+
+ c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
+ _, err = c.connection.WriteTo(packet.ToBytes(), raddr)
+ if err != nil {
+ f <- NewResponse(nil, err)
+ return
+ }
+
+ c.receiveQueue <- packet
+}
+
+func (c *AsyncClient) receive(_ DHCPv6) {
+ var (
+ oobdata = []byte{}
+ received DHCPv6
+ )
+
+ c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout))
+ for {
+ buffer := make([]byte, maxUDPReceivedPacketSize)
+ n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata)
+ if err != nil {
+ if err, ok := err.(net.Error); !ok || !err.Timeout() {
+ c.addError(fmt.Errorf("Error receiving the message: %s", err))
+ }
+ return
+ }
+ received, err = FromBytes(buffer[:n])
+ if err != nil {
+ // skip non-DHCP packets
+ continue
+ }
+ break
+ }
+
+ transactionID, err := GetTransactionID(received)
+ if err != nil {
+ c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err))
+ return
+ }
+
+ c.packetsLock.Lock()
+ if f, ok := c.packets[transactionID]; ok {
+ delete(c.packets, transactionID)
+ f <- NewResponse(received, nil)
+ }
+ c.packetsLock.Unlock()
+}
+
+func (c *AsyncClient) remoteAddr() (*net.UDPAddr, error) {
+ if c.RemoteAddr == nil {
+ return &net.UDPAddr{IP: AllDHCPRelayAgentsAndServers, Port: DefaultServerPort}, nil
+ }
+
+ if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok {
+ return addr, nil
+ }
+ return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr)
+}
+
+// Send inserts a message to the queue to be sent asynchronously.
+// Returns a future which resolves to response and error.
+func (c *AsyncClient) Send(message DHCPv6, modifiers ...Modifier) Future {
+ for _, mod := range modifiers {
+ message = mod(message)
+ }
+
+ transactionID, err := GetTransactionID(message)
+ if err != nil {
+ return NewFailureFuture(err)
+ }
+
+ f := NewFuture()
+ c.packetsLock.Lock()
+ c.packets[transactionID] = f
+ c.packetsLock.Unlock()
+ c.sendQueue <- message
+ return f
+}
+
+// Exchange executes asynchronously a 4-way DHCPv6 request (SOLICIT,
+// ADVERTISE, REQUEST, REPLY).
+func (c *AsyncClient) Exchange(solicit DHCPv6, modifiers ...Modifier) Future {
+ return c.Send(solicit).OnSuccess(func(advertise DHCPv6) Future {
+ request, err := NewRequestFromAdvertise(advertise)
+ if err != nil {
+ return NewFailureFuture(err)
+ }
+ return c.Send(request, modifiers...)
+ })
+}
diff --git a/dhcpv6/async_test.go b/dhcpv6/async_test.go
new file mode 100644
index 0000000..4f5a750
--- /dev/null
+++ b/dhcpv6/async_test.go
@@ -0,0 +1,54 @@
+package dhcpv6
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAsyncClient(t *testing.T) {
+ c := NewAsyncClient()
+ require.NotNil(t, c)
+ require.Equal(t, c.ReadTimeout, DefaultReadTimeout)
+ require.Equal(t, c.ReadTimeout, DefaultWriteTimeout)
+}
+
+func TestOpenInvalidAddrFailes(t *testing.T) {
+ c := NewAsyncClient()
+ err := c.Open(512)
+ require.Error(t, err)
+}
+
+// This test uses port 15438 so please make sure its not used before running
+func TestOpenClose(t *testing.T) {
+ c := NewAsyncClient()
+ addr, err := net.ResolveUDPAddr("udp6", ":15438")
+ require.NoError(t, err)
+ c.LocalAddr = addr
+ err = c.Open(512)
+ require.NoError(t, err)
+ defer c.Close()
+}
+
+// This test uses ports 15438 and 15439 so please make sure they are not used
+// before running
+func TestSendTimeout(t *testing.T) {
+ c := NewAsyncClient()
+ addr, err := net.ResolveUDPAddr("udp6", ":15438")
+ require.NoError(t, err)
+ remote, err := net.ResolveUDPAddr("udp6", ":15439")
+ require.NoError(t, err)
+ c.ReadTimeout = 50 * time.Millisecond
+ c.WriteTimeout = 50 * time.Millisecond
+ c.LocalAddr = addr
+ c.RemoteAddr = remote
+ err = c.Open(512)
+ require.NoError(t, err)
+ defer c.Close()
+ m, err := NewMessage()
+ require.NoError(t, err)
+ _, err = c.Send(m).WaitTimeout(200 * time.Millisecond)
+ require.Error(t, err)
+}
diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go
index fb2a0c1..dbdb623 100644
--- a/dhcpv6/dhcpv6relay.go
+++ b/dhcpv6/dhcpv6relay.go
@@ -199,7 +199,7 @@ func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) {
}
if relay.Type() != RELAY_FORW {
return nil, errors.New("The passed packet is not of type RELAY_FORW")
- }
+ }
if msg == nil {
return nil, errors.New("The passed message cannot be nil")
}
diff --git a/dhcpv6/future.go b/dhcpv6/future.go
new file mode 100644
index 0000000..b431419
--- /dev/null
+++ b/dhcpv6/future.go
@@ -0,0 +1,111 @@
+package dhcpv6
+
+import (
+ "errors"
+ "time"
+)
+
+// Response represents a value which Future resolves to
+type Response interface {
+ Value() DHCPv6
+ Error() error
+}
+
+// Future is a result of an asynchronous DHCPv6 call
+type Future (<-chan Response)
+
+// SuccessFun can be used as a success callback
+type SuccessFun func(val DHCPv6) Future
+
+// FailureFun can be used as a failure callback
+type FailureFun func(err error) Future
+
+type response struct {
+ val DHCPv6
+ err error
+}
+
+func (r *response) Value() DHCPv6 {
+ return r.val
+}
+
+func (r *response) Error() error {
+ return r.err
+}
+
+// NewFuture creates a new future, which can be written to
+func NewFuture() chan Response {
+ return make(chan Response)
+}
+
+// NewResponse creates a new future response
+func NewResponse(val DHCPv6, err error) Response {
+ return &response{val: val, err: err}
+}
+
+// NewSuccessFuture creates a future that resolves to a value
+func NewSuccessFuture(val DHCPv6) Future {
+ f := NewFuture()
+ go func() {
+ f <- NewResponse(val, nil)
+ }()
+ return f
+}
+
+// NewFailureFuture creates a future that resolves to an error
+func NewFailureFuture(err error) Future {
+ f := NewFuture()
+ go func() {
+ f <- NewResponse(nil, err)
+ }()
+ return f
+}
+
+// Then allows to chain the futures executing appropriate function depending
+// on the previous future value
+func (f Future) Then(success SuccessFun, failure FailureFun) Future {
+ g := NewFuture()
+ go func() {
+ r := <-f
+ if r.Error() != nil {
+ r = <-failure(r.Error())
+ g <- r
+ } else {
+ r = <-success(r.Value())
+ g <- r
+ }
+ }()
+ return g
+}
+
+// OnSuccess allows to chain the futures executing next one only if the first
+// one succeeds
+func (f Future) OnSuccess(success SuccessFun) Future {
+ return f.Then(success, func(err error) Future {
+ return NewFailureFuture(err)
+ })
+}
+
+// OnFailure allows to chain the futures executing next one only if the first
+// one fails
+func (f Future) OnFailure(failure FailureFun) Future {
+ return f.Then(func(val DHCPv6) Future {
+ return NewSuccessFuture(val)
+ }, failure)
+}
+
+// Wait blocks the execution until a future resolves
+func (f Future) Wait() (DHCPv6, error) {
+ r := <-f
+ return r.Value(), r.Error()
+}
+
+// WaitTimeout blocks the execution until a future resolves or times out
+func (f Future) WaitTimeout(timeout time.Duration) (DHCPv6, error) {
+ select {
+ case r := <-f:
+ return r.Value(), r.Error()
+ case <-time.After(timeout):
+ return nil, errors.New("Timed out")
+ }
+}
diff --git a/dhcpv6/future_test.go b/dhcpv6/future_test.go
new file mode 100644
index 0000000..bee87e3
--- /dev/null
+++ b/dhcpv6/future_test.go
@@ -0,0 +1,170 @@
+package dhcpv6
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestResponseValue(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ r := NewResponse(m, nil)
+ require.Equal(t, r.Value(), m)
+ require.Equal(t, r.Error(), nil)
+}
+
+func TestResponseError(t *testing.T) {
+ e := errors.New("Test error")
+ r := NewResponse(nil, e)
+ require.Equal(t, r.Value(), nil)
+ require.Equal(t, r.Error(), e)
+}
+
+func TestSuccessFuture(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ f := NewSuccessFuture(m)
+
+ val, err := f.Wait()
+ require.NoError(t, err)
+ require.Equal(t, val, m)
+}
+
+func TestFailureFuture(t *testing.T) {
+ e := errors.New("Test error")
+ f := NewFailureFuture(e)
+
+ val, err := f.Wait()
+ require.Equal(t, err, e)
+ require.Equal(t, val, nil)
+}
+
+func TestThenSuccess(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+ e := errors.New("Test error")
+
+ f := NewSuccessFuture(m).
+ Then(func(_ DHCPv6) Future {
+ return NewSuccessFuture(s)
+ }, func(_ error) Future {
+ return NewFailureFuture(e)
+ })
+
+ val, err := f.Wait()
+ require.NoError(t, err)
+ require.NotEqual(t, val, m)
+ require.Equal(t, val, s)
+}
+
+func TestThenFailure(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+ e := errors.New("Test error")
+ e2 := errors.New("Test error 2")
+
+ f := NewFailureFuture(e).
+ Then(func(_ DHCPv6) Future {
+ return NewSuccessFuture(s)
+ }, func(_ error) Future {
+ return NewFailureFuture(e2)
+ })
+
+ val, err := f.Wait()
+ require.Error(t, err)
+ require.NotEqual(t, val, m)
+ require.NotEqual(t, val, s)
+ require.NotEqual(t, err, e)
+ require.Equal(t, err, e2)
+}
+
+func TestOnSuccess(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+
+ f := NewSuccessFuture(m).
+ OnSuccess(func(_ DHCPv6) Future {
+ return NewSuccessFuture(s)
+ })
+
+ val, err := f.Wait()
+ require.NoError(t, err)
+ require.NotEqual(t, val, m)
+ require.Equal(t, val, s)
+}
+
+func TestOnSuccessForFailureFuture(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ e := errors.New("Test error")
+
+ f := NewFailureFuture(e).
+ OnSuccess(func(_ DHCPv6) Future {
+ return NewSuccessFuture(m)
+ })
+
+ val, err := f.Wait()
+ require.Error(t, err)
+ require.Equal(t, err, e)
+ require.NotEqual(t, val, m)
+}
+
+func TestOnFailure(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+ e := errors.New("Test error")
+
+ f := NewFailureFuture(e).
+ OnFailure(func(_ error) Future {
+ return NewSuccessFuture(s)
+ })
+
+ val, err := f.Wait()
+ require.NoError(t, err)
+ require.NotEqual(t, val, m)
+ require.Equal(t, val, s)
+}
+
+func TestOnFailureForSuccessFuture(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+
+ f := NewSuccessFuture(m).
+ OnFailure(func(_ error) Future {
+ return NewSuccessFuture(s)
+ })
+
+ val, err := f.Wait()
+ require.NoError(t, err)
+ require.NotEqual(t, val, s)
+ require.Equal(t, val, m)
+}
+
+func TestWaitTimeout(t *testing.T) {
+ m, err := NewMessage()
+ require.NoError(t, err)
+ s, err := NewMessage()
+ require.NoError(t, err)
+ f := NewSuccessFuture(m).OnSuccess(func(_ DHCPv6) Future {
+ time.Sleep(1 * time.Second)
+ return NewSuccessFuture(s)
+ })
+ val, err := f.WaitTimeout(50 * time.Millisecond)
+ require.Error(t, err)
+ require.Equal(t, err.Error(), "Timed out")
+ require.NotEqual(t, val, m)
+ require.NotEqual(t, val, s)
+}
diff --git a/dhcpv6/utils.go b/dhcpv6/utils.go
index 81ebaae..b1c0b93 100644
--- a/dhcpv6/utils.go
+++ b/dhcpv6/utils.go
@@ -1,6 +1,7 @@
package dhcpv6
import (
+ "errors"
"strings"
)
@@ -58,3 +59,19 @@ func IsUsingUEFI(msg DHCPv6) bool {
}
return false
}
+
+// GetTransactionID returns a transactionID of a message or its inner message
+// in case of relay
+func GetTransactionID(packet DHCPv6) (uint32, error) {
+ if message, ok := packet.(*DHCPv6Message); ok {
+ return message.TransactionID(), nil
+ }
+ if relay, ok := packet.(*DHCPv6Relay); ok {
+ message, err := relay.GetInnerMessage()
+ if err != nil {
+ return 0, err
+ }
+ return GetTransactionID(message)
+ }
+ return 0, errors.New("Invalid DHCPv6 packet")
+}
diff --git a/dhcpv6/utils_test.go b/dhcpv6/utils_test.go
index 2373691..77205b4 100644
--- a/dhcpv6/utils_test.go
+++ b/dhcpv6/utils_test.go
@@ -49,3 +49,21 @@ func TestIsUsingUEFIUserClassFalse(t *testing.T) {
msg.AddOption(&opt)
require.False(t, IsUsingUEFI(&msg))
}
+
+func TestGetTransactionIDMessage(t *testing.T) {
+ message, err := NewMessage()
+ require.NoError(t, err)
+ transactionID, err := GetTransactionID(message)
+ require.NoError(t, err)
+ require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID())
+}
+
+func TestGetTransactionIDRelay(t *testing.T) {
+ message, err := NewMessage()
+ require.NoError(t, err)
+ relay, err := EncapsulateRelay(message, RELAY_FORW, nil, nil)
+ require.NoError(t, err)
+ transactionID, err := GetTransactionID(relay)
+ require.NoError(t, err)
+ require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID())
+}