summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv6')
-rw-r--r--dhcpv6/dhcpv6_test.go51
-rw-r--r--dhcpv6/dhcpv6message.go4
2 files changed, 49 insertions, 6 deletions
diff --git a/dhcpv6/dhcpv6_test.go b/dhcpv6/dhcpv6_test.go
index 0da839a..49fecca 100644
--- a/dhcpv6/dhcpv6_test.go
+++ b/dhcpv6/dhcpv6_test.go
@@ -1,12 +1,15 @@
package dhcpv6
import (
+ "crypto/rand"
"encoding/binary"
+ "errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
"github.com/insomniacslk/dhcp/iana"
)
@@ -50,11 +53,50 @@ func TestBytesToTransactionIDShortData(t *testing.T) {
require.Nil(t, tid)
}
-func TestGenerateTransactionID(t *testing.T) {
+func randomReadMock(value []byte, n int, err error) func([]byte) (int, error) {
+ return func(b []byte) (int, error) {
+ copy(b, value)
+ return n, err
+ }
+}
+
+type GenerateTransactionIDTestSuite struct {
+ suite.Suite
+ random []byte
+}
+
+func (s *GenerateTransactionIDTestSuite) SetupTest() {
+ s.random = make([]byte, 16)
+}
+
+func (s *GenerateTransactionIDTestSuite) TearDown() {
+ randomRead = rand.Read
+}
+
+func (s *GenerateTransactionIDTestSuite) TestErrors() {
+ // Error is returned from random number generator
+ e := errors.New("mocked error")
+ randomRead = randomReadMock(s.random, 0, e)
tid, err := GenerateTransactionID()
- require.NoError(t, err)
- require.NotNil(t, tid)
- require.True(t, *tid <= 0xffffff, "transaction ID should be smaller than 0xffffff")
+ s.Assert().Equal(e, err)
+ s.Assert().Nil(tid)
+
+ // Less than 4 bytes are generated
+ randomRead = randomReadMock(s.random, 3, nil)
+ tid, err = GenerateTransactionID()
+ s.Assert().EqualError(err, "invalid random sequence: shorter than 4 bytes")
+}
+
+func (s *GenerateTransactionIDTestSuite) TestSuccess() {
+ binary.BigEndian.PutUint32(s.random, 0x01020304)
+ randomRead = randomReadMock(s.random, 4, nil)
+ tid, err := GenerateTransactionID()
+ s.Require().NoError(err)
+ s.Assert().Equal(*tid, uint32(0x00010203))
+}
+
+func TestGenerateTransactionIDTestSuite(t *testing.T) {
+ suite.Run(t, new(GenerateTransactionIDTestSuite))
}
func TestNewMessage(t *testing.T) {
@@ -249,7 +291,6 @@ func TestNewMessageTypeSolicitWithCID(t *testing.T) {
require.Equal(t, len(opts), 2)
}
-
func TestIsUsingUEFIArchTypeTrue(t *testing.T) {
msg := DHCPv6Message{}
opt := OptClientArchType{ArchTypes: []iana.ArchType{iana.EFI_BC}}
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go
index eeb1591..82d44b4 100644
--- a/dhcpv6/dhcpv6message.go
+++ b/dhcpv6/dhcpv6message.go
@@ -34,11 +34,13 @@ func BytesToTransactionID(data []byte) (*uint32, error) {
return &tid, nil
}
+var randomRead = rand.Read
+
func GenerateTransactionID() (*uint32, error) {
var tid *uint32
for {
tidBytes := make([]byte, 4)
- n, err := rand.Read(tidBytes)
+ n, err := randomRead(tidBytes)
if err != nil {
return nil, err
}