diff options
author | insomniac <insomniacslk@users.noreply.github.com> | 2018-07-27 15:39:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-27 15:39:11 +0100 |
commit | 46ce21cd90de7e082c7ef2eaefb6e499fa80fb3a (patch) | |
tree | 97bbf0af2a89e3107de600da77bedced37a857d1 /dhcpv4 | |
parent | 45ecdd89ec87e6c8d4833d987ef94cab25c064c5 (diff) |
[DHCPv4] BroadcastSendReceive now can wait for specific reply types (#95)
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/bsdp/client.go | 4 | ||||
-rw-r--r-- | dhcpv4/client.go | 50 | ||||
-rw-r--r-- | dhcpv4/dhcpv4.go | 10 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 24 | ||||
-rw-r--r-- | dhcpv4/types.go | 3 |
5 files changed, 74 insertions, 17 deletions
diff --git a/dhcpv4/bsdp/client.go b/dhcpv4/bsdp/client.go index 10cf6b4..3d885aa 100644 --- a/dhcpv4/bsdp/client.go +++ b/dhcpv4/bsdp/client.go @@ -60,7 +60,7 @@ func (c *Client) Exchange(ifname string, informList *dhcpv4.DHCPv4) ([]dhcpv4.DH conversation[0] = *informList // ACK[LIST] - ackForList, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informList, c.ReadTimeout, c.WriteTimeout) + ackForList, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informList, c.ReadTimeout, c.WriteTimeout, dhcpv4.MessageTypeAck) if err != nil { return conversation, err } @@ -86,7 +86,7 @@ func (c *Client) Exchange(ifname string, informList *dhcpv4.DHCPv4) ([]dhcpv4.DH conversation = append(conversation, *informSelect) // ACK[SELECT] - ackForSelect, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informSelect, c.ReadTimeout, c.WriteTimeout) + ackForSelect, err := dhcpv4.BroadcastSendReceive(sendFd, recvFd, informSelect, c.ReadTimeout, c.WriteTimeout, dhcpv4.MessageTypeAck) castVendorOpt(ackForSelect) if err != nil { return conversation, err diff --git a/dhcpv4/client.go b/dhcpv4/client.go index fbdc280..1ce6a5c 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -148,7 +148,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { conversation[0] = *discover // Offer - offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout) + offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout, MessageTypeOffer) if err != nil { return conversation, err } @@ -162,7 +162,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { conversation = append(conversation, *request) // Ack - ack, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout) + ack, err := BroadcastSendReceive(sfd, rfd, request, c.ReadTimeout, c.WriteTimeout, MessageTypeAck) if err != nil { return conversation, err } @@ -171,8 +171,9 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { } // BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a -// response up to some read timeout value. -func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) { +// response up to some read timeout value. If the message type is not +// MessageTypeNone, it will wait for a specific message type +func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration, messageType MessageType) (*DHCPv4, error) { packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes()) if err != nil { return nil, err @@ -194,17 +195,36 @@ func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, write defer conn.Close() conn.SetReadDeadline(time.Now().Add(readTimeout)) - buf := make([]byte, MaxUDPReceivedPacketSize) - n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{}) - if err != nil { - errs <- err - return - } - - response, err = FromBytes(buf[:n]) - if err != nil { - errs <- err - return + for { + buf := make([]byte, MaxUDPReceivedPacketSize) + n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{}) + if err != nil { + errs <- err + return + } + + response, err = FromBytes(buf[:n]) + if err != nil { + errs <- err + return + } + // check that this is a response to our message + if response.TransactionID() != packet.TransactionID() { + continue + } + // wait for a response message + if response.Opcode() != OpcodeBootReply { + continue + } + // if we are not requested to wait for a specific message type, + // return what we have + if messageType == MessageTypeNone { + break + } + // break if it's a reply of the desired type, continue otherwise + if response.MessageType() != nil && *response.MessageType() == messageType { + break + } } recvErrors <- nil }(recvErrors) diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 226c077..a88be19 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -561,6 +561,16 @@ func (d *DHCPv4) AddOption(option Option) { d.options = append(d.options, option) } +// MessageType returns the message type, trying to extract it from the +// OptMessageType option. It returns nil if the message type cannot be extracted +func (d *DHCPv4) MessageType() *MessageType { + opt := d.GetOneOption(OptionDHCPMessageType) + if opt == nil { + return nil + } + return &(opt.(*OptMessageType).MessageType) +} + func (d *DHCPv4) String() string { return fmt.Sprintf("DHCPv4(opcode=%v hwtype=%v hwaddr=%v)", d.OpcodeToString(), d.HwTypeToString(), d.ClientHwAddr()) diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index 4e4f7b5..9d44810 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -322,6 +322,30 @@ func TestGetOption(t *testing.T) { require.Equal(t, d.GetOneOption(OptionRouter), nil) } +func TestDHCPv4RequestFromOffer(t *testing.T) { + offer, err := New() + require.NoError(t, err) + offer.AddOption(&OptMessageType{MessageType: MessageTypeOffer}) + offer.AddOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)}) + req, err := RequestFromOffer(*offer) + require.NoError(t, err) + require.NotEqual(t, (*MessageType)(nil), *req.MessageType()) + require.Equal(t, MessageTypeRequest, *req.MessageType()) +} + +func TestDHCPv4MessageTypeNil(t *testing.T) { + m, err := New() + require.NoError(t, err) + require.Equal(t, (*MessageType)(nil), m.MessageType()) +} + +func TestDHCPv4MessageTypeDiscovery(t *testing.T) { + m, err := NewDiscoveryForInterface("lo") + require.NoError(t, err) + require.NotEqual(t, (*MessageType)(nil), m.MessageType()) + require.Equal(t, MessageTypeDiscover, *m.MessageType()) +} + // TODO // test broadcast/unicast flags // test Options setter/getter diff --git a/dhcpv4/types.go b/dhcpv4/types.go index 7c22bff..08a1a77 100644 --- a/dhcpv4/types.go +++ b/dhcpv4/types.go @@ -8,6 +8,9 @@ type MessageType byte // DHCP message types const ( + // MessageTypeNone is not a real message type, it is used by certain + // functions to signal that no explict message type is requested + MessageTypeNone MessageType = 0 MessageTypeDiscover MessageType = 1 MessageTypeOffer MessageType = 2 MessageTypeRequest MessageType = 3 |