diff options
Diffstat (limited to 'dhcpv4/nclient4/client.go')
-rw-r--r-- | dhcpv4/nclient4/client.go | 244 |
1 files changed, 178 insertions, 66 deletions
diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go index 2092e01..17af89c 100644 --- a/dhcpv4/nclient4/client.go +++ b/dhcpv4/nclient4/client.go @@ -18,7 +18,6 @@ import ( "log" "net" "os" - "strings" "sync" "sync/atomic" "time" @@ -27,10 +26,16 @@ import ( ) const ( - defaultTimeout = 5 * time.Second - defaultRetries = 3 defaultBufferCap = 5 - maxMessageSize = 1500 + + // DefaultTimeout is the default value for read-timeout if option WithTimeout is not set + DefaultTimeout = 5 * time.Second + + // DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time + DefaultRetries = 3 + + // MaxMessageSize is the value to be used for DHCP option "MaxMessageSize". + MaxMessageSize = 1500 // ClientPort is the port that DHCP clients listen on. ClientPort = 68 @@ -51,6 +56,12 @@ var ( var ( // ErrNoResponse is returned when no response packet is received. ErrNoResponse = errors.New("no matching response packet received") + + // ErrNoConn is returned when NewWithConn is called with nil-value as conn. + ErrNoConn = errors.New("conn is nil") + + // ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr + ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil") ) // pendingCh is a channel associated with a pending TransactionID. @@ -63,35 +74,61 @@ type pendingCh struct { ch chan<- *dhcpv4.DHCPv4 } -type logger interface { - Printf(format string, v ...interface{}) +// Logger is a handler which will be used to output logging messages +type Logger interface { + // PrintMessage print _all_ DHCP messages PrintMessage(prefix string, message *dhcpv4.DHCPv4) + + // Printf is use to print the rest debugging information + Printf(format string, v ...interface{}) } -type emptyLogger struct{} +// EmptyLogger prints nothing +type EmptyLogger struct{} + +// Printf is just a dummy function that does nothing +func (e EmptyLogger) Printf(format string, v ...interface{}) {} -func (e emptyLogger) Printf(format string, v ...interface{}) {} -func (e emptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {} +// PrintMessage is just a dummy function that does nothing +func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {} -type shortSummaryLogger struct { - *log.Logger +// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer. +type Printfer interface { + // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf. + Printf(format string, v ...interface{}) } -func (s shortSummaryLogger) Printf(format string, v ...interface{}) { - s.Logger.Printf(format, v...) +// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the short format. +type ShortSummaryLogger struct { + // Printfer is used for actual output of the logger + Printfer } -func (s shortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + +// Printf prints a log message as-is via predefined Printfer +func (s ShortSummaryLogger) Printf(format string, v ...interface{}) { + s.Printfer.Printf(format, v...) +} + +// PrintMessage prints a DHCP message in the short format via predefined Printfer +func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { s.Printf("%s: %s", prefix, message) } -type debugLogger struct { - *log.Logger +// DebugLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the long format. +type DebugLogger struct { + // Printfer is used for actual output of the logger + Printfer } -func (d debugLogger) Printf(format string, v ...interface{}) { - d.Logger.Printf(format, v...) +// Printf prints a log message as-is via predefined Printfer +func (d DebugLogger) Printf(format string, v ...interface{}) { + d.Printfer.Printf(format, v...) } -func (d debugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + +// PrintMessage prints a DHCP message in the long format via predefined Printfer +func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { d.Printf("%s: %s", prefix, message.Summary()) } @@ -101,7 +138,7 @@ type Client struct { conn net.PacketConn timeout time.Duration retry int - logger logger + logger Logger // bufferCap is the channel capacity for each TransactionID. bufferCap int @@ -129,39 +166,58 @@ type Client struct { // New returns a client usable with an unconfigured interface. func New(iface string, opts ...ClientOpt) (*Client, error) { - i, err := net.InterfaceByName(iface) - if err != nil { - return nil, err - } - pc, err := NewRawUDPConn(iface, ClientPort) - if err != nil { - return nil, err - } - return NewWithConn(pc, i.HardwareAddr, opts...) + return new(iface, nil, nil, opts...) } // NewWithConn creates a new DHCP client that sends and receives packets on the // given interface. func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { + return new(``, conn, ifaceHWAddr, opts...) +} + +func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { c := &Client{ ifaceHWAddr: ifaceHWAddr, - timeout: defaultTimeout, - retry: defaultRetries, + timeout: DefaultTimeout, + retry: DefaultRetries, serverAddr: DefaultServers, bufferCap: defaultBufferCap, conn: conn, - logger: emptyLogger{}, + logger: EmptyLogger{}, done: make(chan struct{}), pending: make(map[dhcpv4.TransactionID]*pendingCh), } for _, opt := range opts { - opt(c) + err := opt(c) + if err != nil { + return nil, fmt.Errorf("unable to apply option: %w", err) + } + } + + if c.ifaceHWAddr == nil { + if iface == `` { + return nil, ErrNoIfaceHWAddr + } + + i, err := net.InterfaceByName(iface) + if err != nil { + return nil, fmt.Errorf("unable to get interface information: %w", err) + } + + c.ifaceHWAddr = i.HardwareAddr } if c.conn == nil { - return nil, fmt.Errorf("no connection given") + var err error + if iface == `` { + return nil, ErrNoConn + } + c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast + if err != nil { + return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err) + } } c.wg.Add(1) go c.receiveLoop() @@ -193,10 +249,8 @@ func (c *Client) Close() error { return err } -func isErrClosing(err error) bool { - // Unfortunately, the epoll-connection-closed error is internal to the - // net library. - return strings.Contains(err.Error(), "use of closed network connection") +func (c *Client) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 } func (c *Client) receiveLoop() { @@ -204,10 +258,10 @@ func (c *Client) receiveLoop() { for { // TODO: Clients can send a "max packet size" option in their // packets, IIRC. Choose a reasonable size and set it. - b := make([]byte, maxMessageSize) + b := make([]byte, MaxMessageSize) n, _, err := c.conn.ReadFrom(b) if err != nil { - if !isErrClosing(err) { + if !c.isClosed() { c.logger.Printf("error reading from UDP connection: %v", err) } return @@ -249,38 +303,69 @@ func (c *Client) receiveLoop() { } // ClientOpt is a function that configures the Client. -type ClientOpt func(*Client) +type ClientOpt func(c *Client) error // WithTimeout configures the retransmission timeout. // // Default is 5 seconds. func WithTimeout(d time.Duration) ClientOpt { - return func(c *Client) { + return func(c *Client) (err error) { c.timeout = d + return } } -// WithSummaryLogger logs one-line DHCPv4 message summarys when sent & received. +// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received. func WithSummaryLogger() ClientOpt { - return func(c *Client) { - c.logger = shortSummaryLogger{ - Logger: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + return func(c *Client) (err error) { + c.logger = ShortSummaryLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), } + return } } // WithDebugLogger logs multi-line full DHCPv4 messages when sent & received. func WithDebugLogger() ClientOpt { - return func(c *Client) { - c.logger = debugLogger{ - Logger: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + return func(c *Client) (err error) { + c.logger = DebugLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), } + return + } +} + +// WithLogger set the logger (see interface Logger). +func WithLogger(newLogger Logger) ClientOpt { + return func(c *Client) (err error) { + c.logger = newLogger + return + } +} + +// WithUnicast forces client to send messages as unicast frames. +// By default client sends messages as broadcast frames even if server address is defined. +// +// srcAddr is both: +// * The source address of outgoing frames. +// * The address to be listened for incoming frames. +func WithUnicast(srcAddr *net.UDPAddr) ClientOpt { + return func(c *Client) (err error) { + if srcAddr == nil { + srcAddr = &net.UDPAddr{Port: ServerPort} + } + c.conn, err = net.ListenUDP("udp4", srcAddr) + if err != nil { + err = fmt.Errorf("unable to start listening UDP port: %w", err) + } + return } } func withBufferCap(n int) ClientOpt { - return func(c *Client) { + return func(c *Client) (err error) { c.bufferCap = n + return } } @@ -288,15 +373,17 @@ func withBufferCap(n int) ClientOpt { // // Default is 3. func WithRetry(r int) ClientOpt { - return func(c *Client) { + return func(c *Client) (err error) { c.retry = r + return } } // WithServerAddr configures the address to send messages to. func WithServerAddr(n *net.UDPAddr) ClientOpt { - return func(c *Client) { + return func(c *Client) (err error) { c.serverAddr = n + return } } @@ -314,15 +401,23 @@ func IsMessageType(t dhcpv4.MessageType) Matcher { // DiscoverOffer sends a DHCPDiscover message and returns the first valid offer // received. -func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) { +func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) { // RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should // contain. discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers, - dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...) + dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) if err != nil { - return nil, err + err = fmt.Errorf("unable to create a discovery request: %w", err) + return } - return c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer)) + + offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer)) + if err != nil { + err = fmt.Errorf("got an error while the discovery request: %w", err) + return + } + + return } // Request completes the 4-way Discover-Offer-Request-Ack handshake. @@ -331,20 +426,37 @@ func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) { offer, err = c.DiscoverOffer(ctx, modifiers...) if err != nil { - return nil, nil, err + err = fmt.Errorf("unable to receive an offer: %w", err) + return } // TODO(chrisko): should this be unicast to the server? - req, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers, - dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...) + request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers, + dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) if err != nil { - return nil, nil, err + err = fmt.Errorf("unable to create a request: %w", err) + return } - ack, err = c.SendAndRead(ctx, c.serverAddr, req, nil) + + ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil) if err != nil { - return nil, nil, err + err = fmt.Errorf("got an error while processing the request: %w", err) + return } - return offer, ack, nil + + return +} + +// ErrTransactionIDInUse is returned if there were an attempt to send a message +// with the same TransactionID as we are already waiting an answer for. +type ErrTransactionIDInUse struct { + // TransactionID is the transaction ID of the message which the error is related to. + TransactionID dhcpv4.TransactionID +} + +// Error is just the method to comply interface "error". +func (err *ErrTransactionIDInUse) Error() string { + return fmt.Sprintf("transaction ID %s already in use", err.TransactionID) } // send sends p to destination and returns a response channel. @@ -357,7 +469,7 @@ func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv c.pendingMu.Lock() if _, ok := c.pending[msg.TransactionID]; ok { c.pendingMu.Unlock() - return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID) + return nil, nil, &ErrTransactionIDInUse{msg.TransactionID} } ch := make(chan *dhcpv4.DHCPv4, c.bufferCap) @@ -384,7 +496,7 @@ func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil { cancel() - return nil, nil, fmt.Errorf("error writing packet to connection: %v", err) + return nil, nil, fmt.Errorf("error writing packet to connection: %w", err) } return ch, cancel, nil } @@ -441,7 +553,7 @@ func (c *Client) retryFn(fn func(timeout time.Duration) error) error { timeout := c.timeout // Each retry takes the amount of timeout at worst. - for i := 0; i < c.retry || c.retry < 0; i++ { + for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)? switch err := fn(timeout); err { case nil: // Got it! |