summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDmitrii Okunev <xaionaro@fb.com>2019-11-28 12:10:49 +0000
committerDmitrii Okunev <xaionaro@fb.com>2019-11-28 12:10:49 +0000
commit92b156c5580501ec2a4e8a504edbc3db55d7df82 (patch)
tree3161f6119c58160b89a324804c39ede08784c51c
parentec0e0154d15c429b3dcc56af60a2d7b62eb3d5e7 (diff)
Simplified porting from client4 to nclient4
Signed-off-by: Dmitrii Okunev <xaionaro@fb.com>
-rw-r--r--.travis.yml1
-rwxr-xr-x.travis/tests.sh13
-rw-r--r--dhcpv4/client4/client.go1
-rw-r--r--dhcpv4/nclient4/client.go244
-rw-r--r--dhcpv4/nclient4/conn_linux.go9
5 files changed, 198 insertions, 70 deletions
diff --git a/.travis.yml b/.travis.yml
index 14d312e..f257743 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,7 @@ sudo: required
go:
- "1.11"
- "1.12"
+ - "1.13"
- tip
env:
diff --git a/.travis/tests.sh b/.travis/tests.sh
index b48a3c2..edf17ba 100755
--- a/.travis/tests.sh
+++ b/.travis/tests.sh
@@ -10,14 +10,23 @@ echo "" > coverage.txt
# tests.
ip a
+GO_TEST_OPTS=()
+if [[ "$TRAVIS_GO_VERSION" =~ ^1.(9|10|11|12)$ ]]
+then
+ # We use fmt.Errorf with verb "%w" which appeared only in Go1.13.
+ # So the code compiles and works with Go1.12, but error descriptions
+ # looks uglier and it does not pass "vet" tests on Go<1.13.
+ GO_TEST_OPTS+='-vet=off'
+fi
+
for d in $(go list ./... | grep -v vendor); do
- go test -race -coverprofile=profile.out -covermode=atomic $d
+ go test -race -coverprofile=profile.out -covermode=atomic ${GO_TEST_OPTS[@]} $d
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
# integration tests
- go test -c -cover -tags=integration -race -covermode=atomic $d
+ go test -c -cover -tags=integration -race -covermode=atomic ${GO_TEST_OPTS[@]} $d
testbin="./$(basename $d).test"
# only run it if it was built - i.e. if there are integ tests
test -x "${testbin}" && sudo "./${testbin}" -test.coverprofile=profile.out
diff --git a/dhcpv4/client4/client.go b/dhcpv4/client4/client.go
index 85d6152..5c5c18d 100644
--- a/dhcpv4/client4/client.go
+++ b/dhcpv4/client4/client.go
@@ -1,3 +1,4 @@
+// Package client4 is deprecated. Use "nclient4" instead.
package client4
import (
diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go
index 2092e01..17af89c 100644
--- a/dhcpv4/nclient4/client.go
+++ b/dhcpv4/nclient4/client.go
@@ -18,7 +18,6 @@ import (
"log"
"net"
"os"
- "strings"
"sync"
"sync/atomic"
"time"
@@ -27,10 +26,16 @@ import (
)
const (
- defaultTimeout = 5 * time.Second
- defaultRetries = 3
defaultBufferCap = 5
- maxMessageSize = 1500
+
+ // DefaultTimeout is the default value for read-timeout if option WithTimeout is not set
+ DefaultTimeout = 5 * time.Second
+
+ // DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time
+ DefaultRetries = 3
+
+ // MaxMessageSize is the value to be used for DHCP option "MaxMessageSize".
+ MaxMessageSize = 1500
// ClientPort is the port that DHCP clients listen on.
ClientPort = 68
@@ -51,6 +56,12 @@ var (
var (
// ErrNoResponse is returned when no response packet is received.
ErrNoResponse = errors.New("no matching response packet received")
+
+ // ErrNoConn is returned when NewWithConn is called with nil-value as conn.
+ ErrNoConn = errors.New("conn is nil")
+
+ // ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr
+ ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil")
)
// pendingCh is a channel associated with a pending TransactionID.
@@ -63,35 +74,61 @@ type pendingCh struct {
ch chan<- *dhcpv4.DHCPv4
}
-type logger interface {
- Printf(format string, v ...interface{})
+// Logger is a handler which will be used to output logging messages
+type Logger interface {
+ // PrintMessage print _all_ DHCP messages
PrintMessage(prefix string, message *dhcpv4.DHCPv4)
+
+ // Printf is use to print the rest debugging information
+ Printf(format string, v ...interface{})
}
-type emptyLogger struct{}
+// EmptyLogger prints nothing
+type EmptyLogger struct{}
+
+// Printf is just a dummy function that does nothing
+func (e EmptyLogger) Printf(format string, v ...interface{}) {}
-func (e emptyLogger) Printf(format string, v ...interface{}) {}
-func (e emptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {}
+// PrintMessage is just a dummy function that does nothing
+func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {}
-type shortSummaryLogger struct {
- *log.Logger
+// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer.
+type Printfer interface {
+ // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf.
+ Printf(format string, v ...interface{})
}
-func (s shortSummaryLogger) Printf(format string, v ...interface{}) {
- s.Logger.Printf(format, v...)
+// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger.
+// DHCP messages are printed in the short format.
+type ShortSummaryLogger struct {
+ // Printfer is used for actual output of the logger
+ Printfer
}
-func (s shortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
+
+// Printf prints a log message as-is via predefined Printfer
+func (s ShortSummaryLogger) Printf(format string, v ...interface{}) {
+ s.Printfer.Printf(format, v...)
+}
+
+// PrintMessage prints a DHCP message in the short format via predefined Printfer
+func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
s.Printf("%s: %s", prefix, message)
}
-type debugLogger struct {
- *log.Logger
+// DebugLogger is a wrapper for Printfer to implement interface Logger.
+// DHCP messages are printed in the long format.
+type DebugLogger struct {
+ // Printfer is used for actual output of the logger
+ Printfer
}
-func (d debugLogger) Printf(format string, v ...interface{}) {
- d.Logger.Printf(format, v...)
+// Printf prints a log message as-is via predefined Printfer
+func (d DebugLogger) Printf(format string, v ...interface{}) {
+ d.Printfer.Printf(format, v...)
}
-func (d debugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
+
+// PrintMessage prints a DHCP message in the long format via predefined Printfer
+func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
d.Printf("%s: %s", prefix, message.Summary())
}
@@ -101,7 +138,7 @@ type Client struct {
conn net.PacketConn
timeout time.Duration
retry int
- logger logger
+ logger Logger
// bufferCap is the channel capacity for each TransactionID.
bufferCap int
@@ -129,39 +166,58 @@ type Client struct {
// New returns a client usable with an unconfigured interface.
func New(iface string, opts ...ClientOpt) (*Client, error) {
- i, err := net.InterfaceByName(iface)
- if err != nil {
- return nil, err
- }
- pc, err := NewRawUDPConn(iface, ClientPort)
- if err != nil {
- return nil, err
- }
- return NewWithConn(pc, i.HardwareAddr, opts...)
+ return new(iface, nil, nil, opts...)
}
// NewWithConn creates a new DHCP client that sends and receives packets on the
// given interface.
func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
+ return new(``, conn, ifaceHWAddr, opts...)
+}
+
+func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
c := &Client{
ifaceHWAddr: ifaceHWAddr,
- timeout: defaultTimeout,
- retry: defaultRetries,
+ timeout: DefaultTimeout,
+ retry: DefaultRetries,
serverAddr: DefaultServers,
bufferCap: defaultBufferCap,
conn: conn,
- logger: emptyLogger{},
+ logger: EmptyLogger{},
done: make(chan struct{}),
pending: make(map[dhcpv4.TransactionID]*pendingCh),
}
for _, opt := range opts {
- opt(c)
+ err := opt(c)
+ if err != nil {
+ return nil, fmt.Errorf("unable to apply option: %w", err)
+ }
+ }
+
+ if c.ifaceHWAddr == nil {
+ if iface == `` {
+ return nil, ErrNoIfaceHWAddr
+ }
+
+ i, err := net.InterfaceByName(iface)
+ if err != nil {
+ return nil, fmt.Errorf("unable to get interface information: %w", err)
+ }
+
+ c.ifaceHWAddr = i.HardwareAddr
}
if c.conn == nil {
- return nil, fmt.Errorf("no connection given")
+ var err error
+ if iface == `` {
+ return nil, ErrNoConn
+ }
+ c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast
+ if err != nil {
+ return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err)
+ }
}
c.wg.Add(1)
go c.receiveLoop()
@@ -193,10 +249,8 @@ func (c *Client) Close() error {
return err
}
-func isErrClosing(err error) bool {
- // Unfortunately, the epoll-connection-closed error is internal to the
- // net library.
- return strings.Contains(err.Error(), "use of closed network connection")
+func (c *Client) isClosed() bool {
+ return atomic.LoadUint32(&c.closed) != 0
}
func (c *Client) receiveLoop() {
@@ -204,10 +258,10 @@ func (c *Client) receiveLoop() {
for {
// TODO: Clients can send a "max packet size" option in their
// packets, IIRC. Choose a reasonable size and set it.
- b := make([]byte, maxMessageSize)
+ b := make([]byte, MaxMessageSize)
n, _, err := c.conn.ReadFrom(b)
if err != nil {
- if !isErrClosing(err) {
+ if !c.isClosed() {
c.logger.Printf("error reading from UDP connection: %v", err)
}
return
@@ -249,38 +303,69 @@ func (c *Client) receiveLoop() {
}
// ClientOpt is a function that configures the Client.
-type ClientOpt func(*Client)
+type ClientOpt func(c *Client) error
// WithTimeout configures the retransmission timeout.
//
// Default is 5 seconds.
func WithTimeout(d time.Duration) ClientOpt {
- return func(c *Client) {
+ return func(c *Client) (err error) {
c.timeout = d
+ return
}
}
-// WithSummaryLogger logs one-line DHCPv4 message summarys when sent & received.
+// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received.
func WithSummaryLogger() ClientOpt {
- return func(c *Client) {
- c.logger = shortSummaryLogger{
- Logger: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ return func(c *Client) (err error) {
+ c.logger = ShortSummaryLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
}
+ return
}
}
// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received.
func WithDebugLogger() ClientOpt {
- return func(c *Client) {
- c.logger = debugLogger{
- Logger: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ return func(c *Client) (err error) {
+ c.logger = DebugLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
}
+ return
+ }
+}
+
+// WithLogger set the logger (see interface Logger).
+func WithLogger(newLogger Logger) ClientOpt {
+ return func(c *Client) (err error) {
+ c.logger = newLogger
+ return
+ }
+}
+
+// WithUnicast forces client to send messages as unicast frames.
+// By default client sends messages as broadcast frames even if server address is defined.
+//
+// srcAddr is both:
+// * The source address of outgoing frames.
+// * The address to be listened for incoming frames.
+func WithUnicast(srcAddr *net.UDPAddr) ClientOpt {
+ return func(c *Client) (err error) {
+ if srcAddr == nil {
+ srcAddr = &net.UDPAddr{Port: ServerPort}
+ }
+ c.conn, err = net.ListenUDP("udp4", srcAddr)
+ if err != nil {
+ err = fmt.Errorf("unable to start listening UDP port: %w", err)
+ }
+ return
}
}
func withBufferCap(n int) ClientOpt {
- return func(c *Client) {
+ return func(c *Client) (err error) {
c.bufferCap = n
+ return
}
}
@@ -288,15 +373,17 @@ func withBufferCap(n int) ClientOpt {
//
// Default is 3.
func WithRetry(r int) ClientOpt {
- return func(c *Client) {
+ return func(c *Client) (err error) {
c.retry = r
+ return
}
}
// WithServerAddr configures the address to send messages to.
func WithServerAddr(n *net.UDPAddr) ClientOpt {
- return func(c *Client) {
+ return func(c *Client) (err error) {
c.serverAddr = n
+ return
}
}
@@ -314,15 +401,23 @@ func IsMessageType(t dhcpv4.MessageType) Matcher {
// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
// received.
-func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) {
+func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) {
// RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should
// contain.
discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
- dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
- return nil, err
+ err = fmt.Errorf("unable to create a discovery request: %w", err)
+ return
}
- return c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
+
+ offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
+ if err != nil {
+ err = fmt.Errorf("got an error while the discovery request: %w", err)
+ return
+ }
+
+ return
}
// Request completes the 4-way Discover-Offer-Request-Ack handshake.
@@ -331,20 +426,37 @@ func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier
func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) {
offer, err = c.DiscoverOffer(ctx, modifiers...)
if err != nil {
- return nil, nil, err
+ err = fmt.Errorf("unable to receive an offer: %w", err)
+ return
}
// TODO(chrisko): should this be unicast to the server?
- req, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
- dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
- return nil, nil, err
+ err = fmt.Errorf("unable to create a request: %w", err)
+ return
}
- ack, err = c.SendAndRead(ctx, c.serverAddr, req, nil)
+
+ ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil)
if err != nil {
- return nil, nil, err
+ err = fmt.Errorf("got an error while processing the request: %w", err)
+ return
}
- return offer, ack, nil
+
+ return
+}
+
+// ErrTransactionIDInUse is returned if there were an attempt to send a message
+// with the same TransactionID as we are already waiting an answer for.
+type ErrTransactionIDInUse struct {
+ // TransactionID is the transaction ID of the message which the error is related to.
+ TransactionID dhcpv4.TransactionID
+}
+
+// Error is just the method to comply interface "error".
+func (err *ErrTransactionIDInUse) Error() string {
+ return fmt.Sprintf("transaction ID %s already in use", err.TransactionID)
}
// send sends p to destination and returns a response channel.
@@ -357,7 +469,7 @@ func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv
c.pendingMu.Lock()
if _, ok := c.pending[msg.TransactionID]; ok {
c.pendingMu.Unlock()
- return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ return nil, nil, &ErrTransactionIDInUse{msg.TransactionID}
}
ch := make(chan *dhcpv4.DHCPv4, c.bufferCap)
@@ -384,7 +496,7 @@ func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv
if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
cancel()
- return nil, nil, fmt.Errorf("error writing packet to connection: %v", err)
+ return nil, nil, fmt.Errorf("error writing packet to connection: %w", err)
}
return ch, cancel, nil
}
@@ -441,7 +553,7 @@ func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
timeout := c.timeout
// Each retry takes the amount of timeout at worst.
- for i := 0; i < c.retry || c.retry < 0; i++ {
+ for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)?
switch err := fn(timeout); err {
case nil:
// Got it!
diff --git a/dhcpv4/nclient4/conn_linux.go b/dhcpv4/nclient4/conn_linux.go
index 064e6ca..1d0ec3a 100644
--- a/dhcpv4/nclient4/conn_linux.go
+++ b/dhcpv4/nclient4/conn_linux.go
@@ -7,7 +7,7 @@
package nclient4
import (
- "fmt"
+ "errors"
"io"
"net"
@@ -23,6 +23,11 @@ var (
BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255})
)
+var (
+ // ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr".
+ ErrUDPAddrIsRequired = errors.New("must supply UDPAddr")
+)
+
// NewRawUDPConn returns a UDP connection bound to the interface and port
// given based on a raw packet socket. All packets are broadcasted.
//
@@ -127,7 +132,7 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
- return 0, fmt.Errorf("must supply UDPAddr")
+ return 0, ErrUDPAddrIsRequired
}
// Using the boundAddr is not quite right here, but it works.