diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/async/client.go | 208 | ||||
-rw-r--r-- | dhcpv4/async/client_test.go | 125 |
2 files changed, 333 insertions, 0 deletions
diff --git a/dhcpv4/async/client.go b/dhcpv4/async/client.go new file mode 100644 index 0000000..e6c7302 --- /dev/null +++ b/dhcpv4/async/client.go @@ -0,0 +1,208 @@ +package async + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/fanliao/go-promise" + "github.com/insomniacslk/dhcp/dhcpv4" +) + +// Default ports +const ( + DefaultServerPort = 67 + DefaultClientPort = 68 +) + +// Client implements an asynchronous DHCPv4 client +// It doesn't use the broadcast socket! Which means it should be used only when +// the network is already established. +// https://github.com/insomniacslk/dhcp/issues/143 +type Client struct { + ReadTimeout time.Duration + WriteTimeout time.Duration + LocalAddr net.Addr + RemoteAddr net.Addr + IgnoreErrors bool + + connection *net.UDPConn + cancel context.CancelFunc + stopping *sync.WaitGroup + receiveQueue chan *dhcpv4.DHCPv4 + sendQueue chan *dhcpv4.DHCPv4 + packetsLock sync.Mutex + packets map[uint32]*promise.Promise + errors chan error +} + +// NewClient creates an asynchronous client +func NewClient() *Client { + return &Client{ + ReadTimeout: dhcpv4.DefaultReadTimeout, + WriteTimeout: dhcpv4.DefaultWriteTimeout, + } +} + +// Open starts the client. The requests made with Send function call are first +// put to the buffered channel and dispatched in FIFO order. BufferSize +// indicates the number of packets that can be waiting to be send before +// blocking the caller exectution. +func (c *Client) Open(bufferSize int) error { + var ( + addr *net.UDPAddr + ok bool + err error + ) + + if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok { + return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr) + } + + // prepare the socket to listen on for replies + c.connection, err = net.ListenUDP("udp4", addr) + if err != nil { + return err + } + c.stopping = new(sync.WaitGroup) + c.sendQueue = make(chan *dhcpv4.DHCPv4, bufferSize) + c.receiveQueue = make(chan *dhcpv4.DHCPv4, bufferSize) + c.packets = make(map[uint32]*promise.Promise) + c.packetsLock = sync.Mutex{} + c.errors = make(chan error) + + var ctx context.Context + ctx, c.cancel = context.WithCancel(context.Background()) + go c.receiverLoop(ctx) + go c.senderLoop(ctx) + + return nil +} + +// Close stops the client +func (c *Client) Close() { + // Wait for sender and receiver loops + c.stopping.Add(2) + c.cancel() + c.stopping.Wait() + + close(c.sendQueue) + close(c.receiveQueue) + close(c.errors) + + c.connection.Close() +} + +// Errors returns a channel where runtime errors are posted +func (c *Client) Errors() <-chan error { + return c.errors +} + +func (c *Client) addError(err error) { + if !c.IgnoreErrors { + c.errors <- err + } +} + +func (c *Client) receiverLoop(ctx context.Context) { + defer func() { c.stopping.Done() }() + for { + select { + case <-ctx.Done(): + return + case packet := <-c.receiveQueue: + c.receive(packet) + } + } +} + +func (c *Client) senderLoop(ctx context.Context) { + defer func() { c.stopping.Done() }() + for { + select { + case <-ctx.Done(): + return + case packet := <-c.sendQueue: + c.send(packet) + } + } +} + +func (c *Client) send(packet *dhcpv4.DHCPv4) { + c.packetsLock.Lock() + p := c.packets[packet.TransactionID()] + c.packetsLock.Unlock() + + raddr, err := c.remoteAddr() + if err != nil { + p.Reject(err) + return + } + + c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + _, err = c.connection.WriteTo(packet.ToBytes(), raddr) + if err != nil { + p.Reject(err) + return + } + + c.receiveQueue <- packet +} + +func (c *Client) receive(_ *dhcpv4.DHCPv4) { + var ( + oobdata = []byte{} + received *dhcpv4.DHCPv4 + ) + + c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + for { + buffer := make([]byte, dhcpv4.MaxUDPReceivedPacketSize) + n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) + if err != nil { + if err, ok := err.(net.Error); !ok || !err.Timeout() { + c.addError(fmt.Errorf("Error receiving the message: %s", err)) + } + return + } + received, err = dhcpv4.FromBytes(buffer[:n]) + if err == nil { + break + } + } + + c.packetsLock.Lock() + if p, ok := c.packets[received.TransactionID()]; ok { + delete(c.packets, received.TransactionID()) + p.Resolve(received) + } + c.packetsLock.Unlock() +} + +func (c *Client) remoteAddr() (*net.UDPAddr, error) { + if c.RemoteAddr == nil { + return &net.UDPAddr{IP: net.IPv4bcast, Port: DefaultServerPort}, nil + } + + if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { + return addr, nil + } + return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr) +} + +// Send inserts a message to the queue to be sent asynchronously. +// Returns a future which resolves to response and error. +func (c *Client) Send(message *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) *promise.Future { + for _, mod := range modifiers { + message = mod(message) + } + + p := promise.NewPromise() + c.packetsLock.Lock() + c.packets[message.TransactionID()] = p + c.packetsLock.Unlock() + c.sendQueue <- message + return p.Future +} diff --git a/dhcpv4/async/client_test.go b/dhcpv4/async/client_test.go new file mode 100644 index 0000000..4be6edd --- /dev/null +++ b/dhcpv4/async/client_test.go @@ -0,0 +1,125 @@ +package async + +import ( + "context" + "net" + "testing" + "time" + + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/require" +) + +// server creates a server which responds with a predefined response +func serve(ctx context.Context, addr *net.UDPAddr, response *dhcpv4.DHCPv4) error { + conn, err := net.ListenUDP("udp4", addr) + if err != nil { + return err + } + go func() { + defer conn.Close() + oobdata := []byte{} + buffer := make([]byte, dhcpv4.MaxUDPReceivedPacketSize) + for { + select { + case <-ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) + if err != nil { + continue + } + _, err = dhcpv4.FromBytes(buffer[:n]) + if err != nil { + continue + } + conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.WriteTo(response.ToBytes(), src) + if err != nil { + continue + } + } + } + }() + return nil +} + +func TestNewClient(t *testing.T) { + c := NewClient() + require.NotNil(t, c) + require.Equal(t, c.ReadTimeout, dhcpv4.DefaultReadTimeout) + require.Equal(t, c.ReadTimeout, dhcpv4.DefaultWriteTimeout) +} + +func TestOpenInvalidAddrFailes(t *testing.T) { + c := NewClient() + err := c.Open(512) + require.Error(t, err) +} + +// This test uses port 15438 so please make sure its not used before running +func TestOpenClose(t *testing.T) { + c := NewClient() + addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") + require.NoError(t, err) + c.LocalAddr = addr + err = c.Open(512) + require.NoError(t, err) + defer c.Close() +} + +// This test uses ports 15438 and 15439 so please make sure they are not used +// before running +func TestSendTimeout(t *testing.T) { + c := NewClient() + addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") + require.NoError(t, err) + remote, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15439") + require.NoError(t, err) + c.ReadTimeout = 50 * time.Millisecond + c.WriteTimeout = 50 * time.Millisecond + c.LocalAddr = addr + c.RemoteAddr = remote + err = c.Open(512) + require.NoError(t, err) + defer c.Close() + m, err := dhcpv4.New() + require.NoError(t, err) + _, err, timeout := c.Send(m).GetOrTimeout(200) + require.NoError(t, err) + require.True(t, timeout) +} + +// This test uses ports 15438 and 15439 so please make sure they are not used +// before running +func TestSend(t *testing.T) { + m, err := dhcpv4.New() + require.NoError(t, err) + require.NotNil(t, m) + + c := NewClient() + addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") + require.NoError(t, err) + remote, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15439") + require.NoError(t, err) + c.LocalAddr = addr + c.RemoteAddr = remote + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = serve(ctx, remote, m) + require.NoError(t, err) + + err = c.Open(16) + require.NoError(t, err) + defer c.Close() + + f := c.Send(m) + response, err, timeout := f.GetOrTimeout(2000) + r, ok := response.(*dhcpv4.DHCPv4) + require.True(t, ok) + require.False(t, timeout) + require.NoError(t, err) + require.Equal(t, m.TransactionID(), r.TransactionID()) +} |