summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4
diff options
context:
space:
mode:
authorinsomniac <insomniacslk@users.noreply.github.com>2018-07-27 15:39:11 +0100
committerGitHub <noreply@github.com>2018-07-27 15:39:11 +0100
commit46ce21cd90de7e082c7ef2eaefb6e499fa80fb3a (patch)
tree97bbf0af2a89e3107de600da77bedced37a857d1 /dhcpv4
parent45ecdd89ec87e6c8d4833d987ef94cab25c064c5 (diff)
[DHCPv4] BroadcastSendReceive now can wait for specific reply types (#95)
Diffstat (limited to 'dhcpv4')
-rw-r--r--dhcpv4/bsdp/client.go4
-rw-r--r--dhcpv4/client.go50
-rw-r--r--dhcpv4/dhcpv4.go10
-rw-r--r--dhcpv4/dhcpv4_test.go24
-rw-r--r--dhcpv4/types.go3
5 files changed, 74 insertions, 17 deletions
diff --git a/dhcpv4/bsdp/client.go b/dhcpv4/bsdp/client.go
index 10cf6b4..3d885aa 100644
--- a/dhcpv4/bsdp/client.go
+++ b/dhcpv4/bsdp/client.go
@@ -60,7 +60,7 @@ func (c *Client) Exchange(ifname string, informList *dhcpv4.DHCPv4) ([]dhcpv4.DH
conversation[0] = *informList
// ACK[LIST]
- ackForList, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informList, c.ReadTimeout, c.WriteTimeout)
+ ackForList, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informList, c.ReadTimeout, c.WriteTimeout, dhcpv4.MessageTypeAck)
if err != nil {
return conversation, err
}
@@ -86,7 +86,7 @@ func (c *Client) Exchange(ifname string, informList *dhcpv4.DHCPv4) ([]dhcpv4.DH
conversation = append(conversation, *informSelect)
// ACK[SELECT]
- ackForSelect, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informSelect, c.ReadTimeout, c.WriteTimeout)
+ ackForSelect, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informSelect, c.ReadTimeout, c.WriteTimeout, dhcpv4.MessageTypeAck)
castVendorOpt(ackForSelect)
if err != nil {
return conversation, err
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index fbdc280..1ce6a5c 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -148,7 +148,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
conversation[0] = *discover
// Offer
- offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout)
+ offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout, MessageTypeOffer)
if err != nil {
return conversation, err
}
@@ -162,7 +162,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
conversation = append(conversation, *request)
// Ack
- ack, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout)
+ ack, err := BroadcastSendReceive(sfd, rfd, request, c.ReadTimeout, c.WriteTimeout, MessageTypeAck)
if err != nil {
return conversation, err
}
@@ -171,8 +171,9 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
}
// BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a
-// response up to some read timeout value.
-func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) {
+// response up to some read timeout value. If the message type is not
+// MessageTypeNone, it will wait for a specific message type
+func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration, messageType MessageType) (*DHCPv4, error) {
packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes())
if err != nil {
return nil, err
@@ -194,17 +195,36 @@ func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, write
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(readTimeout))
- buf := make([]byte, MaxUDPReceivedPacketSize)
- n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{})
- if err != nil {
- errs <- err
- return
- }
-
- response, err = FromBytes(buf[:n])
- if err != nil {
- errs <- err
- return
+ for {
+ buf := make([]byte, MaxUDPReceivedPacketSize)
+ n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{})
+ if err != nil {
+ errs <- err
+ return
+ }
+
+ response, err = FromBytes(buf[:n])
+ if err != nil {
+ errs <- err
+ return
+ }
+ // check that this is a response to our message
+ if response.TransactionID() != packet.TransactionID() {
+ continue
+ }
+ // wait for a response message
+ if response.Opcode() != OpcodeBootReply {
+ continue
+ }
+ // if we are not requested to wait for a specific message type,
+ // return what we have
+ if messageType == MessageTypeNone {
+ break
+ }
+ // break if it's a reply of the desired type, continue otherwise
+ if response.MessageType() != nil && *response.MessageType() == messageType {
+ break
+ }
}
recvErrors <- nil
}(recvErrors)
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 226c077..a88be19 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -561,6 +561,16 @@ func (d *DHCPv4) AddOption(option Option) {
d.options = append(d.options, option)
}
+// MessageType returns the message type, trying to extract it from the
+// OptMessageType option. It returns nil if the message type cannot be extracted
+func (d *DHCPv4) MessageType() *MessageType {
+ opt := d.GetOneOption(OptionDHCPMessageType)
+ if opt == nil {
+ return nil
+ }
+ return &(opt.(*OptMessageType).MessageType)
+}
+
func (d *DHCPv4) String() string {
return fmt.Sprintf("DHCPv4(opcode=%v hwtype=%v hwaddr=%v)",
d.OpcodeToString(), d.HwTypeToString(), d.ClientHwAddr())
diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go
index 4e4f7b5..9d44810 100644
--- a/dhcpv4/dhcpv4_test.go
+++ b/dhcpv4/dhcpv4_test.go
@@ -322,6 +322,30 @@ func TestGetOption(t *testing.T) {
require.Equal(t, d.GetOneOption(OptionRouter), nil)
}
+func TestDHCPv4RequestFromOffer(t *testing.T) {
+ offer, err := New()
+ require.NoError(t, err)
+ offer.AddOption(&OptMessageType{MessageType: MessageTypeOffer})
+ offer.AddOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)})
+ req, err := RequestFromOffer(*offer)
+ require.NoError(t, err)
+ require.NotEqual(t, (*MessageType)(nil), *req.MessageType())
+ require.Equal(t, MessageTypeRequest, *req.MessageType())
+}
+
+func TestDHCPv4MessageTypeNil(t *testing.T) {
+ m, err := New()
+ require.NoError(t, err)
+ require.Equal(t, (*MessageType)(nil), m.MessageType())
+}
+
+func TestDHCPv4MessageTypeDiscovery(t *testing.T) {
+ m, err := NewDiscoveryForInterface("lo")
+ require.NoError(t, err)
+ require.NotEqual(t, (*MessageType)(nil), m.MessageType())
+ require.Equal(t, MessageTypeDiscover, *m.MessageType())
+}
+
// TODO
// test broadcast/unicast flags
// test Options setter/getter
diff --git a/dhcpv4/types.go b/dhcpv4/types.go
index 7c22bff..08a1a77 100644
--- a/dhcpv4/types.go
+++ b/dhcpv4/types.go
@@ -8,6 +8,9 @@ type MessageType byte
// DHCP message types
const (
+ // MessageTypeNone is not a real message type, it is used by certain
+ // functions to signal that no explict message type is requested
+ MessageTypeNone MessageType = 0
MessageTypeDiscover MessageType = 1
MessageTypeOffer MessageType = 2
MessageTypeRequest MessageType = 3