summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/dhcpv4.go3
-rw-r--r--dhcpv6/async/client.go4
-rw-r--r--dhcpv6/dhcpv6.go16
-rw-r--r--dhcpv6/dhcpv6_test.go69
-rw-r--r--dhcpv6/dhcpv6message.go65
-rw-r--r--dhcpv6/dhcpv6relay_test.go2
-rw-r--r--dhcpv6/option_relaymsg_test.go3
-rw-r--r--dhcpv6/types.go8
8 files changed, 54 insertions, 116 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 5f6943b..1e6c26e 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -112,6 +112,9 @@ func GetExternalIPv4Addrs(addrs []net.Addr) ([]net.IP, error) {
func GenerateTransactionID() (TransactionID, error) {
var xid TransactionID
n, err := rand.Read(xid[:])
+ if err != nil {
+ return xid, err
+ }
if n != 4 {
return xid, errors.New("invalid random sequence for transaction ID: smaller than 32 bits")
}
diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go
index c574208..7a8b9ec 100644
--- a/dhcpv6/async/client.go
+++ b/dhcpv6/async/client.go
@@ -25,7 +25,7 @@ type Client struct {
receiveQueue chan dhcpv6.DHCPv6
sendQueue chan dhcpv6.DHCPv6
packetsLock sync.Mutex
- packets map[uint32]*promise.Promise
+ packets map[dhcpv6.TransactionID]*promise.Promise
errors chan error
}
@@ -69,7 +69,7 @@ func (c *Client) Open(bufferSize int) error {
c.stopping = new(sync.WaitGroup)
c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize)
c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize)
- c.packets = make(map[uint32]*promise.Promise)
+ c.packets = make(map[dhcpv6.TransactionID]*promise.Promise)
c.packetsLock = sync.Mutex{}
c.errors = make(chan error)
diff --git a/dhcpv6/dhcpv6.go b/dhcpv6/dhcpv6.go
index c9f8c05..70c2bff 100644
--- a/dhcpv6/dhcpv6.go
+++ b/dhcpv6/dhcpv6.go
@@ -63,14 +63,10 @@ func FromBytes(data []byte) (DHCPv6, error) {
}
return &d, nil
} else {
- tid, err := BytesToTransactionID(data[1:4])
- if err != nil {
- return nil, err
- }
d := DHCPv6Message{
- messageType: messageType,
- transactionID: *tid,
+ messageType: messageType,
}
+ copy(d.transactionID[:], data[1:4])
if err := d.options.FromBytes(data[4:]); err != nil {
return nil, err
}
@@ -86,7 +82,7 @@ func NewMessage(modifiers ...Modifier) (DHCPv6, error) {
}
msg := DHCPv6Message{
messageType: MessageTypeSolicit,
- transactionID: *tid,
+ transactionID: tid,
}
// apply modifiers
d := DHCPv6(&msg)
@@ -209,16 +205,16 @@ func IsUsingUEFI(msg DHCPv6) bool {
// GetTransactionID returns a transactionID of a message or its inner message
// in case of relay
-func GetTransactionID(packet DHCPv6) (uint32, error) {
+func GetTransactionID(packet DHCPv6) (TransactionID, error) {
if message, ok := packet.(*DHCPv6Message); ok {
return message.TransactionID(), nil
}
if relay, ok := packet.(*DHCPv6Relay); ok {
message, err := relay.GetInnerMessage()
if err != nil {
- return 0, err
+ return TransactionID{0, 0, 0}, err
}
return GetTransactionID(message)
}
- return 0, errors.New("Invalid DHCPv6 packet")
+ return TransactionID{0, 0, 0}, errors.New("Invalid DHCPv6 packet")
}
diff --git a/dhcpv6/dhcpv6_test.go b/dhcpv6/dhcpv6_test.go
index 3a9e3f1..d5a29aa 100644
--- a/dhcpv6/dhcpv6_test.go
+++ b/dhcpv6/dhcpv6_test.go
@@ -7,52 +7,12 @@ import (
"net"
"testing"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/insomniacslk/dhcp/iana"
)
-func TestBytesToTransactionID(t *testing.T) {
- // Check if the function transforms the bytes for the exact length input
- b := make([]byte, 4)
- binary.LittleEndian.PutUint32(b, 0x01020304)
- tid, err := BytesToTransactionID(b)
- require.NoError(t, err)
- require.NotNil(t, tid)
- assert.Equal(t, *tid, uint32(0x000040302))
-
- binary.BigEndian.PutUint32(b, 0x01020304)
- tid, err = BytesToTransactionID(b)
- require.NoError(t, err)
- require.NotNil(t, tid)
- assert.Equal(t, *tid, uint32(0x00010203))
-
- // Check if the function transforms only the first bytes for a longer input
- b = make([]byte, 8)
- binary.LittleEndian.PutUint32(b, 0x01020304)
- binary.LittleEndian.PutUint32(b[4:], 0x11121314)
- tid, err = BytesToTransactionID(b)
- require.NoError(t, err)
- require.NotNil(t, tid)
- assert.Equal(t, *tid, uint32(0x000040302))
-
- binary.BigEndian.PutUint32(b, 0x01020304)
- binary.BigEndian.PutUint32(b[4:], 0x11121314)
- tid, err = BytesToTransactionID(b)
- require.NoError(t, err)
- require.NotNil(t, tid)
- assert.Equal(t, *tid, uint32(0x00010203))
-}
-
-func TestBytesToTransactionIDShortData(t *testing.T) {
- // short sequence, less than three bytes
- tid, err := BytesToTransactionID([]byte{0x11, 0x22})
- require.Error(t, err)
- require.Nil(t, tid)
-}
-
func randomReadMock(value []byte, n int, err error) func([]byte) (int, error) {
return func(b []byte) (int, error) {
copy(b, value)
@@ -77,22 +37,21 @@ func (s *GenerateTransactionIDTestSuite) TestErrors() {
// Error is returned from random number generator
e := errors.New("mocked error")
randomRead = randomReadMock(s.random, 0, e)
- tid, err := GenerateTransactionID()
+ _, err := GenerateTransactionID()
s.Assert().Equal(e, err)
- s.Assert().Nil(tid)
// Less than 4 bytes are generated
- randomRead = randomReadMock(s.random, 3, nil)
+ randomRead = randomReadMock(s.random, 2, nil)
_, err = GenerateTransactionID()
- s.Assert().EqualError(err, "invalid random sequence: shorter than 4 bytes")
+ s.Assert().EqualError(err, "invalid random sequence: shorter than 3 bytes")
}
func (s *GenerateTransactionIDTestSuite) TestSuccess() {
- binary.BigEndian.PutUint32(s.random, 0x01020304)
- randomRead = randomReadMock(s.random, 4, nil)
+ binary.BigEndian.PutUint32(s.random, 0x01020300)
+ randomRead = randomReadMock(s.random, 3, nil)
tid, err := GenerateTransactionID()
s.Require().NoError(err)
- s.Assert().Equal(*tid, uint32(0x00010203))
+ s.Assert().Equal(TransactionID{0x1, 0x2, 0x3}, tid)
}
func TestGenerateTransactionIDTestSuite(t *testing.T) {
@@ -159,8 +118,9 @@ func TestSettersAndGetters(t *testing.T) {
require.Equal(t, MessageTypeAdvertise, d.Type())
// TransactionID
- d.SetTransactionID(12345)
- require.Equal(t, uint32(12345), d.TransactionID())
+ xid := TransactionID{0xa, 0xb, 0xc}
+ d.SetTransactionID(xid)
+ require.Equal(t, xid, d.TransactionID())
// Options
require.Empty(t, d.Options())
@@ -180,11 +140,12 @@ func TestAddOption(t *testing.T) {
func TestToBytes(t *testing.T) {
d := DHCPv6Message{}
d.SetMessage(MessageTypeSolicit)
- d.SetTransactionID(0xabcdef)
+ xid := TransactionID{0xa, 0xb, 0xc}
+ d.SetTransactionID(xid)
opt := OptionGeneric{OptionCode: 0, OptionData: []byte{}}
d.AddOption(&opt)
bytes := d.ToBytes()
- expected := []byte{01, 0xab, 0xcd, 0xef, 0x00, 0x00, 0x00, 0x00}
+ expected := []byte{01, 0xa, 0xb, 0xc, 0x00, 0x00, 0x00, 0x00}
require.Equal(t, expected, bytes)
}
@@ -199,7 +160,8 @@ func TestFromAndToBytes(t *testing.T) {
func TestNewAdvertiseFromSolicit(t *testing.T) {
s := DHCPv6Message{}
s.SetMessage(MessageTypeSolicit)
- s.SetTransactionID(0xabcdef)
+ xid := TransactionID{0xa, 0xb, 0xc}
+ s.SetTransactionID(xid)
cid := OptClientId{}
s.AddOption(&cid)
duid := Duid{}
@@ -212,7 +174,8 @@ func TestNewAdvertiseFromSolicit(t *testing.T) {
func TestNewReplyFromDHCPv6Message(t *testing.T) {
msg := DHCPv6Message{}
- msg.SetTransactionID(0xabcdef)
+ xid := TransactionID{0xa, 0xb, 0xc}
+ msg.SetTransactionID(xid)
cid := OptClientId{}
msg.AddOption(&cid)
sid := OptServerId{}
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go
index abccf43..85d93a3 100644
--- a/dhcpv6/dhcpv6message.go
+++ b/dhcpv6/dhcpv6message.go
@@ -2,7 +2,6 @@ package dhcpv6
import (
"crypto/rand"
- "encoding/binary"
"errors"
"fmt"
"log"
@@ -16,49 +15,21 @@ const MessageHeaderSize = 4
type DHCPv6Message struct {
messageType MessageType
- transactionID uint32 // only 24 bits are used though
+ transactionID TransactionID
options Options
}
-func BytesToTransactionID(data []byte) (*uint32, error) {
- // return a uint32 from a sequence of bytes, representing a transaction ID.
- // Transaction IDs are three-bytes long. If the provided data is shorter than
- // 3 bytes, it return an error. If longer, will use the first three bytes
- // only.
- if len(data) < 3 {
- return nil, fmt.Errorf("Invalid transaction ID: less than 3 bytes")
- }
- buf := make([]byte, 4)
- copy(buf[1:4], data[:3])
- tid := binary.BigEndian.Uint32(buf)
- return &tid, nil
-}
-
var randomRead = rand.Read
-func GenerateTransactionID() (*uint32, error) {
- var tid *uint32
- for {
- tidBytes := make([]byte, 4)
- n, err := randomRead(tidBytes)
- if err != nil {
- return nil, err
- }
- if n != 4 {
- return nil, fmt.Errorf("invalid random sequence: shorter than 4 bytes")
- }
- tid, err = BytesToTransactionID(tidBytes)
- if err != nil {
- return nil, err
- }
- if tid == nil {
- return nil, fmt.Errorf("got a nil Transaction ID")
- }
- // retry until != 0
- // TODO add retry limit
- if *tid != 0 {
- break
- }
+// GenerateTransactionID generates a random 3-byte transaction ID.
+func GenerateTransactionID() (TransactionID, error) {
+ var tid TransactionID
+ n, err := randomRead(tid[:])
+ if err != nil {
+ return tid, err
+ }
+ if n != len(tid) {
+ return tid, fmt.Errorf("invalid random sequence: shorter than 3 bytes")
}
return tid, nil
}
@@ -260,16 +231,14 @@ func (d *DHCPv6Message) MessageTypeToString() string {
return d.messageType.String()
}
-func (d *DHCPv6Message) TransactionID() uint32 {
+// TransactionID returns this message's transaction id.
+func (d *DHCPv6Message) TransactionID() TransactionID {
return d.transactionID
}
-func (d *DHCPv6Message) SetTransactionID(tid uint32) {
- ttid := tid & 0x00ffffff
- if ttid != tid {
- log.Printf("Warning: truncating transaction ID that is longer than 24 bits: %v", tid)
- }
- d.transactionID = ttid
+// SetTransactionID sets this message's transaction id.
+func (d *DHCPv6Message) SetTransactionID(tid TransactionID) {
+ d.transactionID = tid
}
func (d *DHCPv6Message) SetOptions(options []Option) {
@@ -343,9 +312,7 @@ func (d *DHCPv6Message) Summary() string {
func (d *DHCPv6Message) ToBytes() []byte {
var ret []byte
ret = append(ret, byte(d.messageType))
- tidBytes := make([]byte, 4)
- binary.BigEndian.PutUint32(tidBytes, d.transactionID)
- ret = append(ret, tidBytes[1:4]...) // discard the first byte
+ ret = append(ret, d.transactionID[:]...) // discard the first byte
for _, opt := range d.options {
ret = append(ret, opt.ToBytes()...)
}
diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go
index 0710c5d..d2446ec 100644
--- a/dhcpv6/dhcpv6relay_test.go
+++ b/dhcpv6/dhcpv6relay_test.go
@@ -92,7 +92,7 @@ func TestDHCPv6RelayToBytes(t *testing.T) {
opt := OptRelayMsg{
relayMessage: &DHCPv6Message{
messageType: MessageTypeSolicit,
- transactionID: 0xaabbcc,
+ transactionID: TransactionID{0xaa, 0xbb, 0xcc},
options: []Option{
&OptElapsedTime{
ElapsedTime: 0,
diff --git a/dhcpv6/option_relaymsg_test.go b/dhcpv6/option_relaymsg_test.go
index 3e0deaa..996b514 100644
--- a/dhcpv6/option_relaymsg_test.go
+++ b/dhcpv6/option_relaymsg_test.go
@@ -102,7 +102,8 @@ func TestRelayMsgParseOptRelayMsgSingleEncapsulation(t *testing.T) {
MessageTypeSolicit, dType,
)
}
- if tID := innerDHCP.TransactionID(); tID != 0xaabbcc {
+ xid := TransactionID{0xaa, 0xbb, 0xcc}
+ if tID := innerDHCP.TransactionID(); tID != xid {
t.Fatalf("Invalid inner DHCP transaction ID. Expected 0xaabbcc, got %v", tID)
}
if len(innerDHCP.options) != 1 {
diff --git a/dhcpv6/types.go b/dhcpv6/types.go
index f842ff1..77b1e75 100644
--- a/dhcpv6/types.go
+++ b/dhcpv6/types.go
@@ -4,6 +4,14 @@ import (
"fmt"
)
+// TransactionID is a DHCPv6 Transaction ID defined by RFC 3315, Section 6.
+type TransactionID [3]byte
+
+// String prints the transaction ID as a hex value.
+func (xid TransactionID) String() string {
+ return fmt.Sprintf("0x%x", xid[:])
+}
+
// MessageType represents the kind of DHCPv6 message.
type MessageType uint8