summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2020-12-16 15:38:39 -0800
committergVisor bot <gvisor-bot@google.com>2020-12-16 15:41:03 -0800
commit0ac6636aaf83afa0b644b40319c4b1d3c5185427 (patch)
tree9b5b56132acb77721cf7dfea2f44fc5f37a158b0 /pkg/tcpip/transport/udp
parent2ec6e44c9e66d50bdaee40f644dc0779dc946b06 (diff)
Automated rollback of changelist 346565589
PiperOrigin-RevId: 347911316
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go102
3 files changed, 17 insertions, 93 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 153e8c950..7ebae63d8 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,6 +58,5 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 763d1d654..8e16c8435 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1259,6 +1259,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1267,10 +1268,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
- // packets at "Parse" instead of when handling a packet.
- pkt.Data.CapLength(int(hdr.PayloadLength()))
-
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1304,7 +1301,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 08980c298..6f89b6271 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,7 +22,6 @@ import (
"testing"
"time"
- "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -1915,31 +1914,27 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantPayloadLen := len(payload)
+ wantLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
- origDgram := header.UDP(payloadIPHeader.Payload())
- wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
- if got, want := len(origDgram), wantDgramLen; got != want {
- t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
}
- // Correct UDP length to access payload.
- origDgram.SetLength(uint16(wantDgramLen))
-
- if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
- t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -2014,23 +2009,20 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantPayloadLen := len(payload)
+ wantLen := len(payload)
if tc.largePayload {
- wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
- wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
- t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
}
- // Correct UDP length to access payload.
- origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
- if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" {
- t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff)
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}
@@ -2543,67 +2535,3 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
-
-func TestReceiveShortLength(t *testing.T) {
- flows := []testFlow{unicastV4, unicastV6}
- for _, flow := range flows {
- t.Run(flow.String(), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
- }
-
- payload := newPayload()
- extraBytes := []byte{1, 2, 3, 4}
- h := flow.header4Tuple(incoming)
- var buf buffer.View
- var proto tcpip.NetworkProtocolNumber
-
- // Build packets with extra bytes not accounted for in the UDP length
- // field.
- var udp header.UDP
- if flow.isV4() {
- buf = c.buildV4Packet(payload, &h)
- buf = append(buf, extraBytes...)
- ip := header.IPv4(buf)
- ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- proto = ipv4.ProtocolNumber
- udp = ip.Payload()
- } else {
- buf = c.buildV6Packet(payload, &h)
- buf = append(buf, extraBytes...)
- ip := header.IPv6(buf)
- ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
- proto = ipv6.ProtocolNumber
- udp = ip.Payload()
- }
-
- if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
- t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
- }
-
- c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Try to receive the data.
- v, _, err := c.ep.Read(nil)
- if err != nil {
- t.Fatalf("c.ep.Read(nil): %s", err)
- }
-
- // Check the payload is read back without extra bytes.
- if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
- t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}