diff options
Diffstat (limited to 'dhcpv6/async')
-rw-r--r-- | dhcpv6/async/client.go | 236 | ||||
-rw-r--r-- | dhcpv6/async/client_test.go | 151 |
2 files changed, 0 insertions, 387 deletions
diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go deleted file mode 100644 index 8a7fbff..0000000 --- a/dhcpv6/async/client.go +++ /dev/null @@ -1,236 +0,0 @@ -package async - -import ( - "context" - "fmt" - "log" - "net" - "sync" - "time" - - promise "github.com/fanliao/go-promise" - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" -) - -// Client implements an asynchronous DHCPv6 client -type Client struct { - ReadTimeout time.Duration - WriteTimeout time.Duration - LocalAddr net.Addr - RemoteAddr net.Addr - IgnoreErrors bool - - connection *net.UDPConn - cancel context.CancelFunc - stopping *sync.WaitGroup - receiveQueue chan dhcpv6.DHCPv6 - sendQueue chan dhcpv6.DHCPv6 - packetsLock sync.Mutex - packets map[dhcpv6.TransactionID]*promise.Promise - errors chan error -} - -// NewClient creates an asynchronous client -func NewClient() *Client { - return &Client{ - ReadTimeout: client6.DefaultReadTimeout, - WriteTimeout: client6.DefaultWriteTimeout, - } -} - -// OpenForInterface starts the client on the specified interface, replacing -// client LocalAddr with a link-local address of the given interface and -// standard DHCP port (546). -func (c *Client) OpenForInterface(ifname string, bufferSize int) error { - addr, err := dhcpv6.GetLinkLocalAddr(ifname) - if err != nil { - return err - } - c.LocalAddr = &net.UDPAddr{IP: addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} - return c.Open(bufferSize) -} - -// Open starts the client -func (c *Client) Open(bufferSize int) error { - var ( - addr *net.UDPAddr - ok bool - err error - ) - - if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok { - return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr) - } - - // prepare the socket to listen on for replies - c.connection, err = net.ListenUDP("udp6", addr) - if err != nil { - return err - } - c.stopping = new(sync.WaitGroup) - c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize) - c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize) - c.packets = make(map[dhcpv6.TransactionID]*promise.Promise) - c.packetsLock = sync.Mutex{} - c.errors = make(chan error) - - var ctx context.Context - ctx, c.cancel = context.WithCancel(context.Background()) - go c.receiverLoop(ctx) - go c.senderLoop(ctx) - - return nil -} - -// Close stops the client -func (c *Client) Close() { - // Wait for sender and receiver loops - c.stopping.Add(2) - c.cancel() - c.stopping.Wait() - - close(c.sendQueue) - close(c.receiveQueue) - close(c.errors) - - c.connection.Close() -} - -// Errors returns a channel where runtime errors are posted -func (c *Client) Errors() <-chan error { - return c.errors -} - -func (c *Client) addError(err error) { - if !c.IgnoreErrors { - c.errors <- err - } -} - -func (c *Client) receiverLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.receiveQueue: - c.receive(packet) - } - } -} - -func (c *Client) senderLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.sendQueue: - c.send(packet) - } - } -} - -func (c *Client) send(packet dhcpv6.DHCPv6) { - transactionID, err := dhcpv6.GetTransactionID(packet) - if err != nil { - c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet)) - return - } - c.packetsLock.Lock() - p := c.packets[transactionID] - c.packetsLock.Unlock() - - raddr, err := c.remoteAddr() - if err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot get remote address :%v", err) - return - } - - if err := c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot set write deadline :%v", err) - return - } - _, err = c.connection.WriteTo(packet.ToBytes(), raddr) - if err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot write to %s :%v", raddr, err) - return - } - - c.receiveQueue <- packet -} - -func (c *Client) receive(_ dhcpv6.DHCPv6) { - var ( - oobdata = []byte{} - received dhcpv6.DHCPv6 - ) - - if err := c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - log.Printf("Warning: cannot set read deadline :%v", err) - } - for { - buffer := make([]byte, client6.MaxUDPReceivedPacketSize) - n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) - if err != nil { - if err, ok := err.(net.Error); !ok || !err.Timeout() { - c.addError(fmt.Errorf("Error receiving the message: %s", err)) - } - return - } - received, err = dhcpv6.FromBytes(buffer[:n]) - if err != nil { - // skip non-DHCP packets - continue - } - break - } - - transactionID, err := dhcpv6.GetTransactionID(received) - if err != nil { - c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err)) - return - } - - c.packetsLock.Lock() - if p, ok := c.packets[transactionID]; ok { - delete(c.packets, transactionID) - _ = p.Resolve(received) - } - c.packetsLock.Unlock() -} - -func (c *Client) remoteAddr() (*net.UDPAddr, error) { - if c.RemoteAddr == nil { - return &net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort}, nil - } - - if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { - return addr, nil - } - return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr) -} - -// Send inserts a message to the queue to be sent asynchronously. -// Returns a future which resolves to response and error. -func (c *Client) Send(message dhcpv6.DHCPv6, modifiers ...dhcpv6.Modifier) *promise.Future { - for _, mod := range modifiers { - mod(message) - } - - transactionID, err := dhcpv6.GetTransactionID(message) - if err != nil { - return promise.Wrap(err) - } - - p := promise.NewPromise() - c.packetsLock.Lock() - c.packets[transactionID] = p - c.packetsLock.Unlock() - c.sendQueue <- message - return p.Future -} diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go deleted file mode 100644 index 14b8026..0000000 --- a/dhcpv6/async/client_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package async - -import ( - "context" - "net" - "testing" - "time" - - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" - "github.com/stretchr/testify/require" -) - -const retries = 5 - -// solicit creates new solicit based on the mac address -func solicit(input string) (*dhcpv6.Message, error) { - mac, err := net.ParseMAC(input) - if err != nil { - return nil, err - } - return dhcpv6.NewSolicit(mac) -} - -// server creates a server which responds with a predefined response -func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error { - conn, err := net.ListenUDP("udp6", addr) - if err != nil { - return err - } - go func() { - defer conn.Close() - oobdata := []byte{} - buffer := make([]byte, client6.MaxUDPReceivedPacketSize) - for { - select { - case <-ctx.Done(): - return - default: - if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) - if err != nil { - continue - } - _, err = dhcpv6.FromBytes(buffer[:n]) - if err != nil { - continue - } - if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - _, err = conn.WriteTo(response.ToBytes(), src) - if err != nil { - continue - } - } - } - }() - return nil -} - -func TestNewClient(t *testing.T) { - c := NewClient() - require.NotNil(t, c) - require.Equal(t, c.ReadTimeout, client6.DefaultReadTimeout) - require.Equal(t, c.ReadTimeout, client6.DefaultWriteTimeout) -} - -func TestOpenInvalidAddrFailes(t *testing.T) { - c := NewClient() - err := c.Open(512) - require.Error(t, err) -} - -// This test uses port 15438 so please make sure its not used before running -func TestOpenClose(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - c.LocalAddr = addr - err = c.Open(512) - require.NoError(t, err) - defer c.Close() -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSendTimeout(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp6", ":15439") - require.NoError(t, err) - c.ReadTimeout = 50 * time.Millisecond - c.WriteTimeout = 50 * time.Millisecond - c.LocalAddr = addr - c.RemoteAddr = remote - err = c.Open(512) - require.NoError(t, err) - defer c.Close() - m, err := dhcpv6.NewMessage() - require.NoError(t, err) - _, err, timeout := c.Send(m).GetOrTimeout(200) - require.NoError(t, err) - require.True(t, timeout) -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSend(t *testing.T) { - s, err := solicit("c8:6c:2c:47:96:fd") - require.NoError(t, err) - require.NotNil(t, s) - - a, err := dhcpv6.NewAdvertiseFromSolicit(s) - require.NoError(t, err) - require.NotNil(t, a) - - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp6", ":15439") - require.NoError(t, err) - c.LocalAddr = addr - c.RemoteAddr = remote - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err = serve(ctx, remote, a) - require.NoError(t, err) - - err = c.Open(16) - require.NoError(t, err) - defer c.Close() - - f := c.Send(s) - - var passed bool - for i := 0; i < retries; i++ { - response, err, timeout := f.GetOrTimeout(1000) - if timeout { - continue - } - passed = true - require.NoError(t, err) - require.Equal(t, a, response) - } - require.True(t, passed, "All attempts to TestSend timed out") -} |