diff options
Diffstat (limited to 'dhcpv6')
-rw-r--r-- | dhcpv6/nclient6/client.go | 77 |
1 files changed, 56 insertions, 21 deletions
diff --git a/dhcpv6/nclient6/client.go b/dhcpv6/nclient6/client.go index 8705a6a..baec824 100644 --- a/dhcpv6/nclient6/client.go +++ b/dhcpv6/nclient6/client.go @@ -44,7 +44,7 @@ type pendingCh struct { done <-chan struct{} // ch is used by the receive loop to distribute DHCP messages. - ch chan<- *dhcpv6.Message + ch chan<- dhcpv6.DHCPv6 } // Client is a DHCPv6 client. @@ -84,13 +84,13 @@ type Client struct { type logger interface { Printf(format string, v ...interface{}) - PrintMessage(prefix string, message *dhcpv6.Message) + PrintMessage(prefix string, message dhcpv6.DHCPv6) } type emptyLogger struct{} func (e emptyLogger) Printf(format string, v ...interface{}) {} -func (e emptyLogger) PrintMessage(prefix string, message *dhcpv6.Message) {} +func (e emptyLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {} type shortSummaryLogger struct { *log.Logger @@ -99,7 +99,7 @@ type shortSummaryLogger struct { func (s shortSummaryLogger) Printf(format string, v ...interface{}) { s.Logger.Printf(format, v...) } -func (s shortSummaryLogger) PrintMessage(prefix string, message *dhcpv6.Message) { +func (s shortSummaryLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) { s.Printf("%s: %s", prefix, message) } @@ -110,7 +110,7 @@ type debugLogger struct { func (d debugLogger) Printf(format string, v ...interface{}) { d.Logger.Printf(format, v...) } -func (d debugLogger) PrintMessage(prefix string, message *dhcpv6.Message) { +func (d debugLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) { d.Printf("%s: %s", prefix, message.Summary()) } @@ -218,7 +218,7 @@ func (c *Client) receiveLoop() { return } - msg, err := dhcpv6.MessageFromBytes(b[:n]) + msg, err := dhcpv6.FromBytes(b[:n]) if err != nil { // Not a valid DHCP packet; keep listening. if c.printDropped { @@ -230,13 +230,20 @@ func (c *Client) receiveLoop() { continue } + transactionID, err := dhcpv6.GetTransactionID(msg) + if err != nil { + if c.printDropped { + c.logger.Printf("Invalid RelayMessage message received: %s", msg) + } + continue + } c.pendingMu.Lock() - p, ok := c.pending[msg.TransactionID] + p, ok := c.pending[transactionID] if ok { select { case <-p.done: close(p.ch) - delete(c.pending, msg.TransactionID) + delete(c.pending, transactionID) // This send may block. case p.ch <- msg: @@ -311,16 +318,16 @@ func WithDebugLogger() ClientOpt { } // Matcher matches DHCP packets. -type Matcher func(*dhcpv6.Message) bool +type Matcher func(dhcpv6.DHCPv6) bool // IsMessageType returns a matcher that checks for the message type. func IsMessageType(t dhcpv6.MessageType, tt ...dhcpv6.MessageType) Matcher { - return func(p *dhcpv6.Message) bool { - if p.MessageType == t { + return func(p dhcpv6.DHCPv6) bool { + if p.Type() == t { return true } for _, mt := range tt { - if p.MessageType == mt { + if p.Type() == mt { return true } } @@ -355,7 +362,7 @@ func (c *Client) RapidSolicit(ctx context.Context, modifiers ...dhcpv6.Modifier) return nil, err } - switch msg.MessageType { + switch msg.Type() { case dhcpv6.MessageTypeReply: // We got RapidCommitted. return msg, nil @@ -398,16 +405,20 @@ func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifie // received. // // Responses will be matched by transaction ID. -func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) { +func (c *Client) send(dest net.Addr, msg dhcpv6.DHCPv6) (<-chan dhcpv6.DHCPv6, func(), error) { + transactionID, err := dhcpv6.GetTransactionID(msg) + if err != nil { + return nil, nil, err + } c.pendingMu.Lock() - if _, ok := c.pending[msg.TransactionID]; ok { + if _, ok := c.pending[transactionID]; ok { c.pendingMu.Unlock() - return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID) + return nil, nil, fmt.Errorf("transaction ID %s already in use", transactionID) } - ch := make(chan *dhcpv6.Message, c.bufferCap) + ch := make(chan dhcpv6.DHCPv6, c.bufferCap) done := make(chan struct{}) - c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch} + c.pending[transactionID] = &pendingCh{done: done, ch: ch} c.pendingMu.Unlock() cancel := func() { @@ -420,9 +431,9 @@ func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Messag close(done) c.pendingMu.Lock() - if p, ok := c.pending[msg.TransactionID]; ok { + if p, ok := c.pending[transactionID]; ok { close(p.ch) - delete(c.pending, msg.TransactionID) + delete(c.pending, transactionID) } c.pendingMu.Unlock() } @@ -442,7 +453,31 @@ var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded") // // If match is nil, the first packet matching the Transaction ID is returned. func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) { - var response *dhcpv6.Message + response, err := c.SendAndReadDHCPv6(ctx, dest, msg, match) + if err != nil { + return nil, err + } + responseMsg, ok := response.(*dhcpv6.Message) + if !ok { + return nil, fmt.Errorf("require Message response") + } + return responseMsg, nil +} + +func (c *Client) SendAndReadRelay(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.RelayMessage, match Matcher) (*dhcpv6.RelayMessage, error) { + response, err := c.SendAndReadDHCPv6(ctx, dest, msg, match) + if err != nil { + return nil, err + } + responseRelay, ok := response.(*dhcpv6.RelayMessage) + if !ok { + return nil, fmt.Errorf("require RelayMessage response") + } + return responseRelay, nil +} + +func (c *Client) SendAndReadDHCPv6(ctx context.Context, dest *net.UDPAddr, msg dhcpv6.DHCPv6, match Matcher) (dhcpv6.DHCPv6, error) { + var response dhcpv6.DHCPv6 err := c.retryFn(func(timeout time.Duration) error { ch, rem, err := c.send(dest, msg) if err != nil { |