summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6/nclient6/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv6/nclient6/client.go')
-rw-r--r--dhcpv6/nclient6/client.go77
1 files changed, 56 insertions, 21 deletions
diff --git a/dhcpv6/nclient6/client.go b/dhcpv6/nclient6/client.go
index 8705a6a..baec824 100644
--- a/dhcpv6/nclient6/client.go
+++ b/dhcpv6/nclient6/client.go
@@ -44,7 +44,7 @@ type pendingCh struct {
done <-chan struct{}
// ch is used by the receive loop to distribute DHCP messages.
- ch chan<- *dhcpv6.Message
+ ch chan<- dhcpv6.DHCPv6
}
// Client is a DHCPv6 client.
@@ -84,13 +84,13 @@ type Client struct {
type logger interface {
Printf(format string, v ...interface{})
- PrintMessage(prefix string, message *dhcpv6.Message)
+ PrintMessage(prefix string, message dhcpv6.DHCPv6)
}
type emptyLogger struct{}
func (e emptyLogger) Printf(format string, v ...interface{}) {}
-func (e emptyLogger) PrintMessage(prefix string, message *dhcpv6.Message) {}
+func (e emptyLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {}
type shortSummaryLogger struct {
*log.Logger
@@ -99,7 +99,7 @@ type shortSummaryLogger struct {
func (s shortSummaryLogger) Printf(format string, v ...interface{}) {
s.Logger.Printf(format, v...)
}
-func (s shortSummaryLogger) PrintMessage(prefix string, message *dhcpv6.Message) {
+func (s shortSummaryLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {
s.Printf("%s: %s", prefix, message)
}
@@ -110,7 +110,7 @@ type debugLogger struct {
func (d debugLogger) Printf(format string, v ...interface{}) {
d.Logger.Printf(format, v...)
}
-func (d debugLogger) PrintMessage(prefix string, message *dhcpv6.Message) {
+func (d debugLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {
d.Printf("%s: %s", prefix, message.Summary())
}
@@ -218,7 +218,7 @@ func (c *Client) receiveLoop() {
return
}
- msg, err := dhcpv6.MessageFromBytes(b[:n])
+ msg, err := dhcpv6.FromBytes(b[:n])
if err != nil {
// Not a valid DHCP packet; keep listening.
if c.printDropped {
@@ -230,13 +230,20 @@ func (c *Client) receiveLoop() {
continue
}
+ transactionID, err := dhcpv6.GetTransactionID(msg)
+ if err != nil {
+ if c.printDropped {
+ c.logger.Printf("Invalid RelayMessage message received: %s", msg)
+ }
+ continue
+ }
c.pendingMu.Lock()
- p, ok := c.pending[msg.TransactionID]
+ p, ok := c.pending[transactionID]
if ok {
select {
case <-p.done:
close(p.ch)
- delete(c.pending, msg.TransactionID)
+ delete(c.pending, transactionID)
// This send may block.
case p.ch <- msg:
@@ -311,16 +318,16 @@ func WithDebugLogger() ClientOpt {
}
// Matcher matches DHCP packets.
-type Matcher func(*dhcpv6.Message) bool
+type Matcher func(dhcpv6.DHCPv6) bool
// IsMessageType returns a matcher that checks for the message type.
func IsMessageType(t dhcpv6.MessageType, tt ...dhcpv6.MessageType) Matcher {
- return func(p *dhcpv6.Message) bool {
- if p.MessageType == t {
+ return func(p dhcpv6.DHCPv6) bool {
+ if p.Type() == t {
return true
}
for _, mt := range tt {
- if p.MessageType == mt {
+ if p.Type() == mt {
return true
}
}
@@ -355,7 +362,7 @@ func (c *Client) RapidSolicit(ctx context.Context, modifiers ...dhcpv6.Modifier)
return nil, err
}
- switch msg.MessageType {
+ switch msg.Type() {
case dhcpv6.MessageTypeReply:
// We got RapidCommitted.
return msg, nil
@@ -398,16 +405,20 @@ func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifie
// received.
//
// Responses will be matched by transaction ID.
-func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) {
+func (c *Client) send(dest net.Addr, msg dhcpv6.DHCPv6) (<-chan dhcpv6.DHCPv6, func(), error) {
+ transactionID, err := dhcpv6.GetTransactionID(msg)
+ if err != nil {
+ return nil, nil, err
+ }
c.pendingMu.Lock()
- if _, ok := c.pending[msg.TransactionID]; ok {
+ if _, ok := c.pending[transactionID]; ok {
c.pendingMu.Unlock()
- return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ return nil, nil, fmt.Errorf("transaction ID %s already in use", transactionID)
}
- ch := make(chan *dhcpv6.Message, c.bufferCap)
+ ch := make(chan dhcpv6.DHCPv6, c.bufferCap)
done := make(chan struct{})
- c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
+ c.pending[transactionID] = &pendingCh{done: done, ch: ch}
c.pendingMu.Unlock()
cancel := func() {
@@ -420,9 +431,9 @@ func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Messag
close(done)
c.pendingMu.Lock()
- if p, ok := c.pending[msg.TransactionID]; ok {
+ if p, ok := c.pending[transactionID]; ok {
close(p.ch)
- delete(c.pending, msg.TransactionID)
+ delete(c.pending, transactionID)
}
c.pendingMu.Unlock()
}
@@ -442,7 +453,31 @@ var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
//
// If match is nil, the first packet matching the Transaction ID is returned.
func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) {
- var response *dhcpv6.Message
+ response, err := c.SendAndReadDHCPv6(ctx, dest, msg, match)
+ if err != nil {
+ return nil, err
+ }
+ responseMsg, ok := response.(*dhcpv6.Message)
+ if !ok {
+ return nil, fmt.Errorf("require Message response")
+ }
+ return responseMsg, nil
+}
+
+func (c *Client) SendAndReadRelay(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.RelayMessage, match Matcher) (*dhcpv6.RelayMessage, error) {
+ response, err := c.SendAndReadDHCPv6(ctx, dest, msg, match)
+ if err != nil {
+ return nil, err
+ }
+ responseRelay, ok := response.(*dhcpv6.RelayMessage)
+ if !ok {
+ return nil, fmt.Errorf("require RelayMessage response")
+ }
+ return responseRelay, nil
+}
+
+func (c *Client) SendAndReadDHCPv6(ctx context.Context, dest *net.UDPAddr, msg dhcpv6.DHCPv6, match Matcher) (dhcpv6.DHCPv6, error) {
+ var response dhcpv6.DHCPv6
err := c.retryFn(func(timeout time.Duration) error {
ch, rem, err := c.send(dest, msg)
if err != nil {