summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
authorBruno Dal Bo <brunodalbo@google.com>2020-12-09 09:14:45 -0800
committerShentubot <shentubot@google.com>2020-12-09 15:52:28 -0800
commitf6cb96bd57dec4e3baa8c57ccdeb0f1d8706b682 (patch)
tree10377621928f6167b02e4cf8376599d5d3301238 /pkg/tcpip/transport
parent658f874b94ad83d9b4abed47d9aba856fe5f617c (diff)
Cap UDP payload size to length informed in UDP header
startblock: has LGTM from peterjohnston and then add reviewer ghanan,tamird PiperOrigin-RevId: 346565589
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go102
3 files changed, 93 insertions, 17 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5043e7aa5..ee1bb29f8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1267,7 +1267,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1276,6 +1275,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
+ // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
+ // packets at "Parse" instead of when handling a packet.
+ pkt.Data.CapLength(int(hdr.PayloadLength()))
+
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1309,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
},
}
packet.data = pkt.Data
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index e384f52dd..b0b9b2773 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -1827,27 +1828,31 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
+ payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+
+ if got, want := len(origDgram), wantDgramLen; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantDgramLen))
+
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -1922,20 +1927,23 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -2448,3 +2456,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
+
+func TestReceiveShortLength(t *testing.T) {
+ flows := []testFlow{unicastV4, unicastV6}
+ for _, flow := range flows {
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ payload := newPayload()
+ extraBytes := []byte{1, 2, 3, 4}
+ h := flow.header4Tuple(incoming)
+ var buf buffer.View
+ var proto tcpip.NetworkProtocolNumber
+
+ // Build packets with extra bytes not accounted for in the UDP length
+ // field.
+ var udp header.UDP
+ if flow.isV4() {
+ buf = c.buildV4Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv4(buf)
+ ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ proto = ipv4.ProtocolNumber
+ udp = ip.Payload()
+ } else {
+ buf = c.buildV6Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv6(buf)
+ ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
+ proto = ipv6.ProtocolNumber
+ udp = ip.Payload()
+ }
+
+ if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
+ t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
+ }
+
+ c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Try to receive the data.
+ v, _, err := c.ep.Read(nil)
+ if err != nil {
+ t.Fatalf("c.ep.Read(..): %s", err)
+ }
+
+ // Check the payload is read back without extra bytes.
+ if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
+ t.Errorf("c.ep.Read(..) mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}