summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
authorNick Brown <nickbrow@google.com>2021-05-12 16:51:06 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-12 16:53:43 -0700
commit29f4b71eb3db3d082735bd4316006d6bcc3230a1 (patch)
tree868142adfcffdb8ba6a605f67fbd4a520d5cac8f /pkg/tcpip/network/ipv4
parent9854e5ac4d7f80a7db10270313bce7e485ce6f9b (diff)
Send ICMP errors when unable to forward fragmented packets
Before this change, we would silently drop packets when the packet was too big to be sent out through the NIC (and, for IPv4 packets, if DF was set). This change brings us into line with RFC 792 (IPv4) and RFC 4443 (IPv6), both of which specify that gateways should return an ICMP error to the sender when the packet can't be fragmented. PiperOrigin-RevId: 373480078
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go22
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go29
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go208
3 files changed, 214 insertions, 45 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index c8ed1ce79..d1a82b584 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -387,6 +387,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// icmpReason is a marker interface for IPv4 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding indicates whether or not the error arose while attempting to
+ // forward a packet.
isForwarding() bool
}
@@ -463,6 +465,22 @@ func (*icmpReasonNetworkUnreachable) isForwarding() bool {
return true
}
+// icmpReasonFragmentationNeeded is an error where a packet requires
+// fragmentation while also having the Don't Fragment flag set, as per RFC 792
+// page 3, Destination Unreachable Message.
+type icmpReasonFragmentationNeeded struct{}
+
+func (*icmpReasonFragmentationNeeded) isICMPReason() {}
+func (*icmpReasonFragmentationNeeded) isForwarding() bool {
+ // If we hit a Don't Fragment error, then we know we are operating as a router.
+ // As per RFC 792 page 4, Destination Unreachable Message,
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -635,6 +653,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4NetUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonFragmentationNeeded:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4031032d0..aef83e834 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -434,6 +434,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
}
if packetMustBeFragmented(pkt, networkMTU) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment
+ // is set but the packet must be fragmented for the non-forwarding case.
+ return &tcpip.ErrMessageTooLong{}
+ }
sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
@@ -695,13 +701,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
- if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- })); err != nil {
+ IsForwardedPacket: true,
+ })); err.(type) {
+ case nil:
+ return nil
+ case *tcpip.ErrMessageTooLong:
+ // As per RFC 792, page 4, Destination Unreachable:
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ //
+ // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit
+ // while sending the packet and returns this error iff fragmentation is
+ // necessary and the bit is also set.
+ _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt)
+ return &ip.ErrMessageTooLong{}
+ default:
return &ip.ErrOther{Err: err}
}
- return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -830,6 +851,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
case *ip.ErrParameterProblem:
e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
+ case *ip.ErrMessageTooLong:
+ stats.ip.Forwarding.PacketTooBig.Increment()
default:
panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 7a7cad04a..3c8a39973 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -112,6 +112,10 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+type forwardedPacket struct {
+ fragments []fragmentInfo
+}
+
func TestForwarding(t *testing.T) {
const (
nicID1 = 1
@@ -129,6 +133,7 @@ func TestForwarding(t *testing.T) {
Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
PrefixLen: 8,
}
+ linkAddr2 := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4())
@@ -141,7 +146,9 @@ func TestForwarding(t *testing.T) {
sourceAddr tcpip.Address
destAddr tcpip.Address
expectErrorICMP bool
- expectPacketForwarded bool
+ ipFlags uint8
+ mtu uint32
+ payloadLength int
options header.IPv4Options
forwardedOptions header.IPv4Options
icmpType header.ICMPv4Type
@@ -149,6 +156,8 @@ func TestForwarding(t *testing.T) {
expectPacketUnrouteableError bool
expectLinkLocalSourceError bool
expectLinkLocalDestError bool
+ expectPacketForwarded bool
+ expectedFragmentsForwarded []fragmentInfo
}{
{
name: "TTL of zero",
@@ -158,6 +167,7 @@ func TestForwarding(t *testing.T) {
expectErrorICMP: true,
icmpType: header.ICMPv4TimeExceeded,
icmpCode: header.ICMPv4TTLExceeded,
+ mtu: ipv4.MaxTotalSize,
},
{
name: "TTL of one",
@@ -165,6 +175,7 @@ func TestForwarding(t *testing.T) {
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
name: "TTL of two",
@@ -172,6 +183,7 @@ func TestForwarding(t *testing.T) {
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
name: "Max TTL",
@@ -179,6 +191,7 @@ func TestForwarding(t *testing.T) {
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
name: "four EOL options",
@@ -186,6 +199,7 @@ func TestForwarding(t *testing.T) {
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{0, 0, 0, 0},
forwardedOptions: header.IPv4Options{0, 0, 0, 0},
},
@@ -194,6 +208,7 @@ func TestForwarding(t *testing.T) {
TTL: 2,
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0xF1,
192, 168, 1, 12,
@@ -208,6 +223,7 @@ func TestForwarding(t *testing.T) {
TTL: 2,
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 24, 21, 0x00,
1, 2, 3, 4,
@@ -231,6 +247,7 @@ func TestForwarding(t *testing.T) {
TTL: 2,
sourceAddr: remoteIPv4Addr1,
destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0x11,
192, 168, 1, 12,
@@ -254,6 +271,7 @@ func TestForwarding(t *testing.T) {
sourceAddr: remoteIPv4Addr1,
destAddr: unreachableIPv4Addr,
expectErrorICMP: true,
+ mtu: ipv4.MaxTotalSize,
icmpType: header.ICMPv4DstUnreachable,
icmpCode: header.ICMPv4NetUnreachable,
expectPacketUnrouteableError: true,
@@ -278,6 +296,51 @@ func TestForwarding(t *testing.T) {
destAddr: remoteIPv4Addr2,
expectLinkLocalSourceError: true,
},
+ {
+ name: "Fragmentation needed and DF set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ ipFlags: header.IPv4FlagDontFragment,
+ // We've picked this MTU because it is:
+ //
+ // 1) Greater than the minimum MTU that IPv4 hosts are required to process
+ // (576 bytes). As per RFC 1812, Section 4.3.2.3:
+ //
+ // The ICMP datagram SHOULD contain as much of the original datagram as
+ // possible without the length of the ICMP datagram exceeding 576 bytes.
+ //
+ // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a
+ // complete ICMP packet on the incoming endpoint (and make assertions about
+ // it).
+ //
+ // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose
+ // size exceeds the MTU.
+ mtu: 1000,
+ payloadLength: 1004,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4FragmentationNeeded,
+ },
+ {
+ name: "Fragmentation needed and DF not set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: 1000,
+ payloadLength: 1004,
+ expectPacketForwarded: true,
+ // Combined, these fragments have length of 1012 octets, which is equal to
+ // the length of the payload (1004 octets), plus the length of the ICMP
+ // header (8 octets).
+ expectedFragmentsForwarded: []fragmentInfo{
+ // The first fragment has a length of the greatest multiple of 8 which is
+ // less than or equal to to `mtu - header.IPv4MinimumSize`.
+ {offset: 0, payloadSize: uint16(976), more: true},
+ // The next fragment holds the rest of the packet.
+ {offset: uint16(976), payloadSize: 36, more: false},
+ },
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
@@ -293,7 +356,7 @@ func TestForwarding(t *testing.T) {
clock.Advance(time.Millisecond * randomTimeOffset)
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, ipv4.MaxTotalSize, "")
+ e1 := channel.New(1, test.mtu, "")
if err := s.CreateNIC(nicID1, e1); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
@@ -302,7 +365,11 @@ func TestForwarding(t *testing.T) {
t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
}
- e2 := channel.New(1, ipv4.MaxTotalSize, "")
+ expectedEmittedPacketCount := 1
+ if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount {
+ expectedEmittedPacketCount = len(test.expectedFragmentsForwarded)
+ }
+ e2 := channel.New(expectedEmittedPacketCount, test.mtu, linkAddr2)
if err := s.CreateNIC(nicID2, e2); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
@@ -330,9 +397,11 @@ func TestForwarding(t *testing.T) {
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpHeaderLength := header.ICMPv4MinimumSize
+ totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
+ hdr := buffer.NewPrependable(totalLength)
+ hdr.Prepend(test.payloadLength)
+ icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
icmp.SetType(header.ICMPv4Echo)
@@ -341,11 +410,12 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(^header.Checksum(icmp, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
+ TotalLength: uint16(totalLength),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: test.TTL,
SrcAddr: test.sourceAddr,
DstAddr: test.destAddr,
+ Flags: test.ipFlags,
})
if len(test.options) != 0 {
ip.SetHeaderLength(uint8(ipHeaderLength))
@@ -360,6 +430,7 @@ func TestForwarding(t *testing.T) {
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
+ requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
reply, ok := e1.Read()
@@ -368,6 +439,18 @@ func TestForwarding(t *testing.T) {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
+ // We expect the ICMP packet to contain as much of the original packet as
+ // possible up to a limit of 576 bytes, split between payload, IP header,
+ // and ICMP header.
+ expectedICMPPayloadLength := func() int {
+ maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize
+ maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength
+ if len(hdr.View()) > maxICMPPayloadLength {
+ return maxICMPPayloadLength
+ }
+ return len(hdr.View())
+ }
+
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
checker.SrcAddr(ipv4Addr1.Address),
checker.DstAddr(test.sourceAddr),
@@ -376,41 +459,58 @@ func TestForwarding(t *testing.T) {
checker.ICMPv4Checksum(),
checker.ICMPv4Type(test.icmpType),
checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
),
)
-
- if n := e2.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
- }
} else if ok {
t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
}
- reply, ok = e2.Read()
if test.expectPacketForwarded {
- if !ok {
- t.Fatal("expected ICMP Echo packet through outgoing NIC")
- }
+ if len(test.expectedFragmentsForwarded) != 0 {
+ fragmentedPackets := []*stack.PacketBuffer{}
+ for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
+ reply, ok = e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo fragment through outgoing NIC")
+ }
+ fragmentedPackets = append(fragmentedPackets, reply.Pkt)
+ }
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(test.sourceAddr),
- checker.DstAddr(test.destAddr),
- checker.TTL(test.TTL-1),
- checker.IPv4Options(test.forwardedOptions),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Payload(nil),
- ),
- )
+ // The forwarded packet's TTL will have been decremented.
+ ipHeader := header.IPv4(requestPkt.NetworkHeader().View())
+ ipHeader.SetTTL(ipHeader.TTL() - 1)
+
+ // Forwarded packets have available header bytes equalling the sum of the
+ // maximum IP header size and the maximum size allocated for link layer
+ // headers. In this case, no size is allocated for link layer headers.
+ expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
+ if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
+ t.Error(err)
+ }
+ } else {
+ reply, ok = e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
- if n := e1.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
+ checker.TTL(test.TTL-1),
+ checker.IPv4Options(test.forwardedOptions),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
+ }
+ } else {
+ if reply, ok = e2.Read(); ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
}
- } else if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
}
boolToInt := func(val bool) uint64 {
@@ -443,6 +543,10 @@ func TestForwarding(t *testing.T) {
if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
}
+
+ if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
+ }
})
}
}
@@ -1264,13 +1368,25 @@ func TestIPv4Sanity(t *testing.T) {
}
}
-// comparePayloads compared the contents of all the packets against the contents
-// of the source packet.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+// compareFragments compares the contents of a set of fragmented packets against
+// the contents of a source packet.
+//
+// If withIPHeader is set to true, we will validate the fragmented packets' IP
+// headers against the source packet's IP header. If set to false, we validate
+// the fragmented packets' IP headers against each other.
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error {
// Make a complete array of the sourcePacket packet.
- source := header.IPv4(packets[0].NetworkHeader().View())
+ var source header.IPv4
vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
+
+ // If the packet to be fragmented contains an IPv4 header, use that header for
+ // validating fragment headers. Else, use the header of the first fragment.
+ if withIPHeader {
+ source = header.IPv4(vv.ToView())
+ } else {
+ source = header.IPv4(packets[0].NetworkHeader().View())
+ source = append(source, vv.ToView()...)
+ }
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -1293,12 +1409,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
if got := fragmentIPHeader.TransportProtocol(); got != proto {
return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
}
- if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
}
+ if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes)
+ }
if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
}
@@ -1314,6 +1430,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+
+ // If we are validating against the original IP header, we should exclude the
+ // ID field, which will only be set fo fragmented packets.
+ if withIPHeader {
+ fragmentIPHeader.SetID(0)
+ fragmentIPHeader.SetChecksum(0)
+ fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum())
+ }
if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
@@ -1442,7 +1566,7 @@ func TestFragmentationWritePacket(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})
@@ -1523,7 +1647,7 @@ func TestFragmentationWritePackets(t *testing.T) {
}
fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})