diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/dhcpv4.go | 10 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 37 | ||||
-rw-r--r-- | dhcpv4/fuzz.go | 41 | ||||
-rw-r--r-- | dhcpv4/option_routes.go | 3 | ||||
-rw-r--r-- | dhcpv4/option_routes_test.go | 14 | ||||
-rw-r--r-- | dhcpv4/options.go | 10 | ||||
-rw-r--r-- | dhcpv4/options_test.go | 18 |
7 files changed, 108 insertions, 25 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index ff72c9e..e85fc60 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -144,6 +144,7 @@ func New(modifiers ...Modifier) (*DHCPv4, error) { d := DHCPv4{ OpCode: OpcodeBootRequest, HWType: iana.HWTypeEthernet, + ClientHWAddr: make(net.HardwareAddr, 6), HopCount: 0, TransactionID: xid, NumSeconds: 0, @@ -476,9 +477,6 @@ func (d *DHCPv4) ToBytes() []byte { // HwAddrLen hlen := uint8(len(d.ClientHWAddr)) - if hlen == 0 && d.HWType == iana.HWTypeEthernet { - hlen = 6 - } buf.Write8(hlen) buf.Write8(d.HopCount) buf.WriteBytes(d.TransactionID[:]) @@ -492,13 +490,11 @@ func (d *DHCPv4) ToBytes() []byte { copy(buf.WriteN(16), d.ClientHWAddr) var sname [64]byte - copy(sname[:], []byte(d.ServerHostName)) - sname[len(d.ServerHostName)] = 0 + copy(sname[:63], []byte(d.ServerHostName)) buf.WriteBytes(sname[:]) var file [128]byte - copy(file[:], []byte(d.BootFileName)) - file[len(d.BootFileName)] = 0 + copy(file[:127], []byte(d.BootFileName)) buf.WriteBytes(file[:]) // The magic cookie. diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index ea3776c..6bbee31 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "bytes" "net" + "strconv" + "strings" "testing" "github.com/insomniacslk/dhcp/iana" @@ -80,16 +82,18 @@ func TestFromBytes(t *testing.T) { // above } -func TestFromBytesZeroLength(t *testing.T) { - data := []byte{} - _, err := FromBytes(data) - require.Error(t, err) -} - -func TestFromBytesShortLength(t *testing.T) { - data := []byte{1, 1, 6, 0} - _, err := FromBytes(data) - require.Error(t, err) +func TestFromBytesGenericInvalid(t *testing.T) { + data := [][]byte{ + {}, + {1, 1, 6, 0}, + } + t.Parallel() + for i, packet := range data { + t.Run(strconv.Itoa(i), func(t *testing.T) { + _, err := FromBytes(packet) + require.Error(t, err) + }) + } } func TestFromBytesInvalidOptions(t *testing.T) { @@ -178,6 +182,17 @@ func TestNewToBytes(t *testing.T) { require.Equal(t, expected, got) } +func TestToBytesStringTooLong(t *testing.T) { + d, err := New() + if err != nil { + t.Fatal(err) + } + d.ServerHostName = strings.Repeat("a", 256) + d.BootFileName = strings.Repeat("a", 256) + + require.NotPanics(t, func() { _ = d.ToBytes() }) +} + func TestGetOption(t *testing.T) { d, err := New() if err != nil { @@ -339,7 +354,7 @@ func TestSummary(t *testing.T) { " your IP: 0.0.0.0\n" + " server IP: 0.0.0.0\n" + " gateway IP: 0.0.0.0\n" + - " client MAC: \n" + + " client MAC: 00:00:00:00:00:00\n" + " server hostname: \n" + " bootfile name: \n" + " options:\n" + diff --git a/dhcpv4/fuzz.go b/dhcpv4/fuzz.go new file mode 100644 index 0000000..cf62ba5 --- /dev/null +++ b/dhcpv4/fuzz.go @@ -0,0 +1,41 @@ +// +build gofuzz + +package dhcpv4 + +import ( + "fmt" + "reflect" +) + +// Fuzz is the entrypoint for go-fuzz +func Fuzz(data []byte) int { + msg, err := FromBytes(data) + if err != nil { + return 0 + } + + serialized := msg.ToBytes() + + // Compared to dhcpv6, dhcpv4 has padding and fixed-size fields containing + // variable-length data; We can't expect the library to output byte-for-byte + // identical packets after a round-trip. + // Instead, we check that after a round-trip, the packet reserializes to the + // same internal representation + rtMsg, err := FromBytes(serialized) + + if err != nil || !reflect.DeepEqual(msg, rtMsg) { + fmt.Printf("Input: %x\n", data) + fmt.Printf("Round-trip: %x\n", serialized) + fmt.Println("Message: ", msg.Summary()) + fmt.Printf("Go repr: %#v\n", msg) + fmt.Println("Reserialized: ", rtMsg.Summary()) + fmt.Printf("Go repr: %#v\n", rtMsg) + if err != nil { + fmt.Printf("Got error while reserializing: %v\n", err) + panic("round-trip error: " + err.Error()) + } + panic("round-trip different: " + msg.Summary()) + } + + return 1 +} diff --git a/dhcpv4/option_routes.go b/dhcpv4/option_routes.go index 603273a..c98d481 100644 --- a/dhcpv4/option_routes.go +++ b/dhcpv4/option_routes.go @@ -38,6 +38,9 @@ func (r Route) Marshal(buf *uio.Lexer) { // Unmarshal implements uio.Unmarshaler. func (r *Route) Unmarshal(buf *uio.Lexer) error { maskSize := buf.Read8() + if maskSize > 32 { + return fmt.Errorf("invalid mask length %d in route option", maskSize) + } r.Dest = &net.IPNet{ IP: make([]byte, net.IPv4len), Mask: net.CIDRMask(int(maskSize), 32), diff --git a/dhcpv4/option_routes_test.go b/dhcpv4/option_routes_test.go index 19e331b..33f0ce7 100644 --- a/dhcpv4/option_routes_test.go +++ b/dhcpv4/option_routes_test.go @@ -16,9 +16,9 @@ func mustParseIPNet(s string) *net.IPNet { func TestParseRoutes(t *testing.T) { for _, tt := range []struct { - p []byte - want Routes - err error + p []byte + want Routes + wantErr bool }{ { p: []byte{32, 10, 2, 3, 4, 0, 0, 0, 0}, @@ -51,10 +51,14 @@ func TestParseRoutes(t *testing.T) { }, }, }, + { + p: []byte{64, 10, 2, 3, 4}, + wantErr: true, // Mask length 64 > 32 + }, } { var r Routes - if err := r.FromBytes(tt.p); err != tt.err { - t.Errorf("FromBytes(%v) = %v, want %v", tt.p, err, tt.err) + if err := r.FromBytes(tt.p); (err != nil) != tt.wantErr { + t.Errorf("FromBytes(%v) Unexpected error state: %v", tt.p, err) } if !reflect.DeepEqual(r, tt.want) { diff --git a/dhcpv4/options.go b/dhcpv4/options.go index ea902f6..058c4ad 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -185,12 +185,20 @@ func (o Options) Marshal(b *uio.Lexer) { code := uint8(c) // Even if the End option is in there, don't marshal it until // the end. - if code == optEnd { + // Don't write padding either, since the options are sorted + // it would always be written first which isn't useful + if code == optEnd || code == optPad { continue } data := o[code] + // Ensure even 0-length options are written out + if len(data) == 0 { + b.Write8(code) + b.Write8(0) + continue + } // RFC 3396: If more than 256 bytes of data are given, the // option is simply listed multiple times. for len(data) > 0 { diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 6c5393c..3850d2d 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -194,9 +194,25 @@ func TestOptionsMarshal(t *testing.T) { 5, 1, 10, ), }, + { + // Test 0-length options + opts: Options{ + 80: []byte{}, + }, + want: []byte{80, 0}, + }, + { + // Test special options, handled by the message marshalling code + // and ignored by the options marshalling code + opts: Options{ + 0: []byte{}, // Padding + 255: []byte{}, // End of options + }, + want: nil, // not written out + }, } { t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { - require.Equal(t, uio.ToBigEndian(tt.opts), tt.want) + require.Equal(t, tt.want, uio.ToBigEndian(tt.opts)) }) } } |