diff options
-rw-r--r-- | dhcpv4/dhcpv4.go | 3 | ||||
-rw-r--r-- | dhcpv6/async/client.go | 4 | ||||
-rw-r--r-- | dhcpv6/dhcpv6.go | 16 | ||||
-rw-r--r-- | dhcpv6/dhcpv6_test.go | 69 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message.go | 65 | ||||
-rw-r--r-- | dhcpv6/dhcpv6relay_test.go | 2 | ||||
-rw-r--r-- | dhcpv6/option_relaymsg_test.go | 3 | ||||
-rw-r--r-- | dhcpv6/types.go | 8 |
8 files changed, 54 insertions, 116 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 5f6943b..1e6c26e 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -112,6 +112,9 @@ func GetExternalIPv4Addrs(addrs []net.Addr) ([]net.IP, error) { func GenerateTransactionID() (TransactionID, error) { var xid TransactionID n, err := rand.Read(xid[:]) + if err != nil { + return xid, err + } if n != 4 { return xid, errors.New("invalid random sequence for transaction ID: smaller than 32 bits") } diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go index c574208..7a8b9ec 100644 --- a/dhcpv6/async/client.go +++ b/dhcpv6/async/client.go @@ -25,7 +25,7 @@ type Client struct { receiveQueue chan dhcpv6.DHCPv6 sendQueue chan dhcpv6.DHCPv6 packetsLock sync.Mutex - packets map[uint32]*promise.Promise + packets map[dhcpv6.TransactionID]*promise.Promise errors chan error } @@ -69,7 +69,7 @@ func (c *Client) Open(bufferSize int) error { c.stopping = new(sync.WaitGroup) c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize) c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize) - c.packets = make(map[uint32]*promise.Promise) + c.packets = make(map[dhcpv6.TransactionID]*promise.Promise) c.packetsLock = sync.Mutex{} c.errors = make(chan error) diff --git a/dhcpv6/dhcpv6.go b/dhcpv6/dhcpv6.go index c9f8c05..70c2bff 100644 --- a/dhcpv6/dhcpv6.go +++ b/dhcpv6/dhcpv6.go @@ -63,14 +63,10 @@ func FromBytes(data []byte) (DHCPv6, error) { } return &d, nil } else { - tid, err := BytesToTransactionID(data[1:4]) - if err != nil { - return nil, err - } d := DHCPv6Message{ - messageType: messageType, - transactionID: *tid, + messageType: messageType, } + copy(d.transactionID[:], data[1:4]) if err := d.options.FromBytes(data[4:]); err != nil { return nil, err } @@ -86,7 +82,7 @@ func NewMessage(modifiers ...Modifier) (DHCPv6, error) { } msg := DHCPv6Message{ messageType: MessageTypeSolicit, - transactionID: *tid, + transactionID: tid, } // apply modifiers d := DHCPv6(&msg) @@ -209,16 +205,16 @@ func IsUsingUEFI(msg DHCPv6) bool { // GetTransactionID returns a transactionID of a message or its inner message // in case of relay -func GetTransactionID(packet DHCPv6) (uint32, error) { +func GetTransactionID(packet DHCPv6) (TransactionID, error) { if message, ok := packet.(*DHCPv6Message); ok { return message.TransactionID(), nil } if relay, ok := packet.(*DHCPv6Relay); ok { message, err := relay.GetInnerMessage() if err != nil { - return 0, err + return TransactionID{0, 0, 0}, err } return GetTransactionID(message) } - return 0, errors.New("Invalid DHCPv6 packet") + return TransactionID{0, 0, 0}, errors.New("Invalid DHCPv6 packet") } diff --git a/dhcpv6/dhcpv6_test.go b/dhcpv6/dhcpv6_test.go index 3a9e3f1..d5a29aa 100644 --- a/dhcpv6/dhcpv6_test.go +++ b/dhcpv6/dhcpv6_test.go @@ -7,52 +7,12 @@ import ( "net" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/insomniacslk/dhcp/iana" ) -func TestBytesToTransactionID(t *testing.T) { - // Check if the function transforms the bytes for the exact length input - b := make([]byte, 4) - binary.LittleEndian.PutUint32(b, 0x01020304) - tid, err := BytesToTransactionID(b) - require.NoError(t, err) - require.NotNil(t, tid) - assert.Equal(t, *tid, uint32(0x000040302)) - - binary.BigEndian.PutUint32(b, 0x01020304) - tid, err = BytesToTransactionID(b) - require.NoError(t, err) - require.NotNil(t, tid) - assert.Equal(t, *tid, uint32(0x00010203)) - - // Check if the function transforms only the first bytes for a longer input - b = make([]byte, 8) - binary.LittleEndian.PutUint32(b, 0x01020304) - binary.LittleEndian.PutUint32(b[4:], 0x11121314) - tid, err = BytesToTransactionID(b) - require.NoError(t, err) - require.NotNil(t, tid) - assert.Equal(t, *tid, uint32(0x000040302)) - - binary.BigEndian.PutUint32(b, 0x01020304) - binary.BigEndian.PutUint32(b[4:], 0x11121314) - tid, err = BytesToTransactionID(b) - require.NoError(t, err) - require.NotNil(t, tid) - assert.Equal(t, *tid, uint32(0x00010203)) -} - -func TestBytesToTransactionIDShortData(t *testing.T) { - // short sequence, less than three bytes - tid, err := BytesToTransactionID([]byte{0x11, 0x22}) - require.Error(t, err) - require.Nil(t, tid) -} - func randomReadMock(value []byte, n int, err error) func([]byte) (int, error) { return func(b []byte) (int, error) { copy(b, value) @@ -77,22 +37,21 @@ func (s *GenerateTransactionIDTestSuite) TestErrors() { // Error is returned from random number generator e := errors.New("mocked error") randomRead = randomReadMock(s.random, 0, e) - tid, err := GenerateTransactionID() + _, err := GenerateTransactionID() s.Assert().Equal(e, err) - s.Assert().Nil(tid) // Less than 4 bytes are generated - randomRead = randomReadMock(s.random, 3, nil) + randomRead = randomReadMock(s.random, 2, nil) _, err = GenerateTransactionID() - s.Assert().EqualError(err, "invalid random sequence: shorter than 4 bytes") + s.Assert().EqualError(err, "invalid random sequence: shorter than 3 bytes") } func (s *GenerateTransactionIDTestSuite) TestSuccess() { - binary.BigEndian.PutUint32(s.random, 0x01020304) - randomRead = randomReadMock(s.random, 4, nil) + binary.BigEndian.PutUint32(s.random, 0x01020300) + randomRead = randomReadMock(s.random, 3, nil) tid, err := GenerateTransactionID() s.Require().NoError(err) - s.Assert().Equal(*tid, uint32(0x00010203)) + s.Assert().Equal(TransactionID{0x1, 0x2, 0x3}, tid) } func TestGenerateTransactionIDTestSuite(t *testing.T) { @@ -159,8 +118,9 @@ func TestSettersAndGetters(t *testing.T) { require.Equal(t, MessageTypeAdvertise, d.Type()) // TransactionID - d.SetTransactionID(12345) - require.Equal(t, uint32(12345), d.TransactionID()) + xid := TransactionID{0xa, 0xb, 0xc} + d.SetTransactionID(xid) + require.Equal(t, xid, d.TransactionID()) // Options require.Empty(t, d.Options()) @@ -180,11 +140,12 @@ func TestAddOption(t *testing.T) { func TestToBytes(t *testing.T) { d := DHCPv6Message{} d.SetMessage(MessageTypeSolicit) - d.SetTransactionID(0xabcdef) + xid := TransactionID{0xa, 0xb, 0xc} + d.SetTransactionID(xid) opt := OptionGeneric{OptionCode: 0, OptionData: []byte{}} d.AddOption(&opt) bytes := d.ToBytes() - expected := []byte{01, 0xab, 0xcd, 0xef, 0x00, 0x00, 0x00, 0x00} + expected := []byte{01, 0xa, 0xb, 0xc, 0x00, 0x00, 0x00, 0x00} require.Equal(t, expected, bytes) } @@ -199,7 +160,8 @@ func TestFromAndToBytes(t *testing.T) { func TestNewAdvertiseFromSolicit(t *testing.T) { s := DHCPv6Message{} s.SetMessage(MessageTypeSolicit) - s.SetTransactionID(0xabcdef) + xid := TransactionID{0xa, 0xb, 0xc} + s.SetTransactionID(xid) cid := OptClientId{} s.AddOption(&cid) duid := Duid{} @@ -212,7 +174,8 @@ func TestNewAdvertiseFromSolicit(t *testing.T) { func TestNewReplyFromDHCPv6Message(t *testing.T) { msg := DHCPv6Message{} - msg.SetTransactionID(0xabcdef) + xid := TransactionID{0xa, 0xb, 0xc} + msg.SetTransactionID(xid) cid := OptClientId{} msg.AddOption(&cid) sid := OptServerId{} diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index abccf43..85d93a3 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -2,7 +2,6 @@ package dhcpv6 import ( "crypto/rand" - "encoding/binary" "errors" "fmt" "log" @@ -16,49 +15,21 @@ const MessageHeaderSize = 4 type DHCPv6Message struct { messageType MessageType - transactionID uint32 // only 24 bits are used though + transactionID TransactionID options Options } -func BytesToTransactionID(data []byte) (*uint32, error) { - // return a uint32 from a sequence of bytes, representing a transaction ID. - // Transaction IDs are three-bytes long. If the provided data is shorter than - // 3 bytes, it return an error. If longer, will use the first three bytes - // only. - if len(data) < 3 { - return nil, fmt.Errorf("Invalid transaction ID: less than 3 bytes") - } - buf := make([]byte, 4) - copy(buf[1:4], data[:3]) - tid := binary.BigEndian.Uint32(buf) - return &tid, nil -} - var randomRead = rand.Read -func GenerateTransactionID() (*uint32, error) { - var tid *uint32 - for { - tidBytes := make([]byte, 4) - n, err := randomRead(tidBytes) - if err != nil { - return nil, err - } - if n != 4 { - return nil, fmt.Errorf("invalid random sequence: shorter than 4 bytes") - } - tid, err = BytesToTransactionID(tidBytes) - if err != nil { - return nil, err - } - if tid == nil { - return nil, fmt.Errorf("got a nil Transaction ID") - } - // retry until != 0 - // TODO add retry limit - if *tid != 0 { - break - } +// GenerateTransactionID generates a random 3-byte transaction ID. +func GenerateTransactionID() (TransactionID, error) { + var tid TransactionID + n, err := randomRead(tid[:]) + if err != nil { + return tid, err + } + if n != len(tid) { + return tid, fmt.Errorf("invalid random sequence: shorter than 3 bytes") } return tid, nil } @@ -260,16 +231,14 @@ func (d *DHCPv6Message) MessageTypeToString() string { return d.messageType.String() } -func (d *DHCPv6Message) TransactionID() uint32 { +// TransactionID returns this message's transaction id. +func (d *DHCPv6Message) TransactionID() TransactionID { return d.transactionID } -func (d *DHCPv6Message) SetTransactionID(tid uint32) { - ttid := tid & 0x00ffffff - if ttid != tid { - log.Printf("Warning: truncating transaction ID that is longer than 24 bits: %v", tid) - } - d.transactionID = ttid +// SetTransactionID sets this message's transaction id. +func (d *DHCPv6Message) SetTransactionID(tid TransactionID) { + d.transactionID = tid } func (d *DHCPv6Message) SetOptions(options []Option) { @@ -343,9 +312,7 @@ func (d *DHCPv6Message) Summary() string { func (d *DHCPv6Message) ToBytes() []byte { var ret []byte ret = append(ret, byte(d.messageType)) - tidBytes := make([]byte, 4) - binary.BigEndian.PutUint32(tidBytes, d.transactionID) - ret = append(ret, tidBytes[1:4]...) // discard the first byte + ret = append(ret, d.transactionID[:]...) // discard the first byte for _, opt := range d.options { ret = append(ret, opt.ToBytes()...) } diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go index 0710c5d..d2446ec 100644 --- a/dhcpv6/dhcpv6relay_test.go +++ b/dhcpv6/dhcpv6relay_test.go @@ -92,7 +92,7 @@ func TestDHCPv6RelayToBytes(t *testing.T) { opt := OptRelayMsg{ relayMessage: &DHCPv6Message{ messageType: MessageTypeSolicit, - transactionID: 0xaabbcc, + transactionID: TransactionID{0xaa, 0xbb, 0xcc}, options: []Option{ &OptElapsedTime{ ElapsedTime: 0, diff --git a/dhcpv6/option_relaymsg_test.go b/dhcpv6/option_relaymsg_test.go index 3e0deaa..996b514 100644 --- a/dhcpv6/option_relaymsg_test.go +++ b/dhcpv6/option_relaymsg_test.go @@ -102,7 +102,8 @@ func TestRelayMsgParseOptRelayMsgSingleEncapsulation(t *testing.T) { MessageTypeSolicit, dType, ) } - if tID := innerDHCP.TransactionID(); tID != 0xaabbcc { + xid := TransactionID{0xaa, 0xbb, 0xcc} + if tID := innerDHCP.TransactionID(); tID != xid { t.Fatalf("Invalid inner DHCP transaction ID. Expected 0xaabbcc, got %v", tID) } if len(innerDHCP.options) != 1 { diff --git a/dhcpv6/types.go b/dhcpv6/types.go index f842ff1..77b1e75 100644 --- a/dhcpv6/types.go +++ b/dhcpv6/types.go @@ -4,6 +4,14 @@ import ( "fmt" ) +// TransactionID is a DHCPv6 Transaction ID defined by RFC 3315, Section 6. +type TransactionID [3]byte + +// String prints the transaction ID as a hex value. +func (xid TransactionID) String() string { + return fmt.Sprintf("0x%x", xid[:]) +} + // MessageType represents the kind of DHCPv6 message. type MessageType uint8 |