diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-12-16 15:38:39 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-12-16 15:41:03 -0800 |
commit | 0ac6636aaf83afa0b644b40319c4b1d3c5185427 (patch) | |
tree | 9b5b56132acb77721cf7dfea2f44fc5f37a158b0 /pkg/tcpip/transport/udp | |
parent | 2ec6e44c9e66d50bdaee40f644dc0779dc946b06 (diff) |
Automated rollback of changelist 346565589
PiperOrigin-RevId: 347911316
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 102 |
3 files changed, 17 insertions, 93 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 153e8c950..7ebae63d8 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -58,6 +58,5 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 763d1d654..8e16c8435 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1259,6 +1259,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1267,10 +1268,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap - // packets at "Parse" instead of when handling a packet. - pkt.Data.CapLength(int(hdr.PayloadLength())) - if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() @@ -1304,7 +1301,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 08980c298..6f89b6271 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,7 +22,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1915,31 +1914,27 @@ func TestV4UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantPayloadLen := len(payload) + wantLen := len(payload) if tc.largePayload { // To work out the data size we need to simulate what the sender would // have done. The wanted size is the total available minus the sum of // the headers in the UDP AND ICMP packets, given that we know the test // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. - wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength } // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) - origDgram := header.UDP(payloadIPHeader.Payload()) - wantDgramLen := wantPayloadLen + header.UDPMinimumSize + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - if got, want := len(origDgram), wantDgramLen; got != want { - t.Fatalf("got len(origDgram) = %d, want = %d", got, want) + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) } - // Correct UDP length to access payload. - origDgram.SetLength(uint16(wantDgramLen)) - - if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { - t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -2014,23 +2009,20 @@ func TestV6UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv6(hdr.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantPayloadLen := len(payload) + wantLen := len(payload) if tc.largePayload { - wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize } - wantDgramLen := wantPayloadLen + header.UDPMinimumSize // In case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { - t.Fatalf("got len(origDgram) = %d, want = %d", got, want) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) } - // Correct UDP length to access payload. - origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) - if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" { - t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff) + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } @@ -2543,67 +2535,3 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } - -func TestReceiveShortLength(t *testing.T) { - flows := []testFlow{unicastV4, unicastV6} - for _, flow := range flows { - t.Run(flow.String(), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err) - } - - payload := newPayload() - extraBytes := []byte{1, 2, 3, 4} - h := flow.header4Tuple(incoming) - var buf buffer.View - var proto tcpip.NetworkProtocolNumber - - // Build packets with extra bytes not accounted for in the UDP length - // field. - var udp header.UDP - if flow.isV4() { - buf = c.buildV4Packet(payload, &h) - buf = append(buf, extraBytes...) - ip := header.IPv4(buf) - ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes))) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - proto = ipv4.ProtocolNumber - udp = ip.Payload() - } else { - buf = c.buildV6Packet(payload, &h) - buf = append(buf, extraBytes...) - ip := header.IPv6(buf) - ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes))) - proto = ipv6.ProtocolNumber - udp = ip.Payload() - } - - if diff := cmp.Diff(payload, udp.Payload()); diff != "" { - t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff) - } - - c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Try to receive the data. - v, _, err := c.ep.Read(nil) - if err != nil { - t.Fatalf("c.ep.Read(nil): %s", err) - } - - // Check the payload is read back without extra bytes. - if diff := cmp.Diff(buffer.View(payload), v); diff != "" { - t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff) - } - }) - } -} |