package async import ( "context" "fmt" "net" "sync" "time" "github.com/fanliao/go-promise" "github.com/insomniacslk/dhcp/dhcpv6" ) // Client implements an asynchronous DHCPv6 client type Client struct { ReadTimeout time.Duration WriteTimeout time.Duration LocalAddr net.Addr RemoteAddr net.Addr IgnoreErrors bool connection *net.UDPConn cancel context.CancelFunc stopping *sync.WaitGroup receiveQueue chan dhcpv6.DHCPv6 sendQueue chan dhcpv6.DHCPv6 packetsLock sync.Mutex packets map[uint32]*promise.Promise errors chan error } // NewClient creates an asynchronous client func NewClient() *Client { return &Client{ ReadTimeout: dhcpv6.DefaultReadTimeout, WriteTimeout: dhcpv6.DefaultWriteTimeout, } } // OpenForInterface starts the client on the specified interface, replacing // client LocalAddr with a link-local address of the given interface and // standard DHCP port (546). func (c *Client) OpenForInterface(ifname string, bufferSize int) error { addr, err := dhcpv6.GetLinkLocalAddr(ifname) if err != nil { return err } c.LocalAddr = &net.UDPAddr{IP: addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} return c.Open(bufferSize) } // Open starts the client func (c *Client) Open(bufferSize int) error { var ( addr *net.UDPAddr ok bool err error ) if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok { return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr) } // prepare the socket to listen on for replies c.connection, err = net.ListenUDP("udp6", addr) if err != nil { return err } c.stopping = new(sync.WaitGroup) c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize) c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize) c.packets = make(map[uint32]*promise.Promise) c.packetsLock = sync.Mutex{} c.errors = make(chan error) var ctx context.Context ctx, c.cancel = context.WithCancel(context.Background()) go c.receiverLoop(ctx) go c.senderLoop(ctx) return nil } // Close stops the client func (c *Client) Close() { // Wait for sender and receiver loops c.stopping.Add(2) c.cancel() c.stopping.Wait() close(c.sendQueue) close(c.receiveQueue) close(c.errors) c.connection.Close() } // Errors returns a channel where runtime errors are posted func (c *Client) Errors() <-chan error { return c.errors } func (c *Client) addError(err error) { if !c.IgnoreErrors { c.errors <- err } } func (c *Client) receiverLoop(ctx context.Context) { defer func() { c.stopping.Done() }() for { select { case <-ctx.Done(): return case packet := <-c.receiveQueue: c.receive(packet) } } } func (c *Client) senderLoop(ctx context.Context) { defer func() { c.stopping.Done() }() for { select { case <-ctx.Done(): return case packet := <-c.sendQueue: c.send(packet) } } } func (c *Client) send(packet dhcpv6.DHCPv6) { transactionID, err := dhcpv6.GetTransactionID(packet) if err != nil { c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet)) return } c.packetsLock.Lock() p := c.packets[transactionID] c.packetsLock.Unlock() raddr, err := c.remoteAddr() if err != nil { p.Reject(err) return } c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) _, err = c.connection.WriteTo(packet.ToBytes(), raddr) if err != nil { p.Reject(err) return } c.receiveQueue <- packet } func (c *Client) receive(_ dhcpv6.DHCPv6) { var ( oobdata = []byte{} received dhcpv6.DHCPv6 ) c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)) for { buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) if err != nil { if err, ok := err.(net.Error); !ok || !err.Timeout() { c.addError(fmt.Errorf("Error receiving the message: %s", err)) } return } received, err = dhcpv6.FromBytes(buffer[:n]) if err != nil { // skip non-DHCP packets continue } break } transactionID, err := dhcpv6.GetTransactionID(received) if err != nil { c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err)) return } c.packetsLock.Lock() if p, ok := c.packets[transactionID]; ok { delete(c.packets, transactionID) p.Resolve(received) } c.packetsLock.Unlock() } func (c *Client) remoteAddr() (*net.UDPAddr, error) { if c.RemoteAddr == nil { return &net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort}, nil } if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { return addr, nil } return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr) } // Send inserts a message to the queue to be sent asynchronously. // Returns a future which resolves to response and error. func (c *Client) Send(message dhcpv6.DHCPv6, modifiers ...dhcpv6.Modifier) *promise.Future { for _, mod := range modifiers { message = mod(message) } transactionID, err := dhcpv6.GetTransactionID(message) if err != nil { return promise.Wrap(err) } p := promise.NewPromise() c.packetsLock.Lock() c.packets[transactionID] = p c.packetsLock.Unlock() c.sendQueue <- message return p.Future }