diff options
-rw-r--r-- | pkg/tcpip/header/udp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 102 |
4 files changed, 29 insertions, 110 deletions
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index a6d4fcd59..98bdd29db 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -36,10 +36,10 @@ const ( // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { - // SrcPort is the "Source Port" field of a UDP packet. + // SrcPort is the "source port" field of a UDP packet. SrcPort uint16 - // DstPort is the "Destination Port" field of a UDP packet. + // DstPort is the "destination port" field of a UDP packet. DstPort uint16 // Length is the "length" field of a UDP packet. @@ -64,57 +64,52 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "Source Port" field of the UDP header. +// SourcePort returns the "source port" field of the udp header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "Destination Port" field of the UDP header. +// DestinationPort returns the "destination port" field of the udp header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "Length" field of the UDP header. +// Length returns the "length" field of the udp header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } // Payload returns the data contained in the UDP datagram. func (b UDP) Payload() []byte { - return b[:b.Length()][UDPMinimumSize:] + return b[UDPMinimumSize:] } -// Checksum returns the "checksum" field of the UDP header. +// Checksum returns the "checksum" field of the udp header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the UDP header. +// SetSourcePort sets the "source port" field of the udp header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the UDP header. +// SetDestinationPort sets the "destination port" field of the udp header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the UDP header. +// SetChecksum sets the "checksum" field of the udp header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the UDP header. +// SetLength sets the "length" field of the udp header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// PayloadLength returns the length of the payload following the UDP header. -func (b UDP) PayloadLength() uint16 { - return b.Length() - UDPMinimumSize -} - -// CalculateChecksum calculates the checksum of the UDP packet, given the +// CalculateChecksum calculates the checksum of the udp packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 153e8c950..7ebae63d8 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -58,6 +58,5 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 763d1d654..8e16c8435 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1259,6 +1259,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1267,10 +1268,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap - // packets at "Parse" instead of when handling a packet. - pkt.Data.CapLength(int(hdr.PayloadLength())) - if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() @@ -1304,7 +1301,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 08980c298..6f89b6271 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,7 +22,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1915,31 +1914,27 @@ func TestV4UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantPayloadLen := len(payload) + wantLen := len(payload) if tc.largePayload { // To work out the data size we need to simulate what the sender would // have done. The wanted size is the total available minus the sum of // the headers in the UDP AND ICMP packets, given that we know the test // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. - wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength } // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) - origDgram := header.UDP(payloadIPHeader.Payload()) - wantDgramLen := wantPayloadLen + header.UDPMinimumSize + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - if got, want := len(origDgram), wantDgramLen; got != want { - t.Fatalf("got len(origDgram) = %d, want = %d", got, want) + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) } - // Correct UDP length to access payload. - origDgram.SetLength(uint16(wantDgramLen)) - - if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { - t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -2014,23 +2009,20 @@ func TestV6UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv6(hdr.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantPayloadLen := len(payload) + wantLen := len(payload) if tc.largePayload { - wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize } - wantDgramLen := wantPayloadLen + header.UDPMinimumSize // In case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { - t.Fatalf("got len(origDgram) = %d, want = %d", got, want) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) } - // Correct UDP length to access payload. - origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) - if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" { - t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff) + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } @@ -2543,67 +2535,3 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } - -func TestReceiveShortLength(t *testing.T) { - flows := []testFlow{unicastV4, unicastV6} - for _, flow := range flows { - t.Run(flow.String(), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err) - } - - payload := newPayload() - extraBytes := []byte{1, 2, 3, 4} - h := flow.header4Tuple(incoming) - var buf buffer.View - var proto tcpip.NetworkProtocolNumber - - // Build packets with extra bytes not accounted for in the UDP length - // field. - var udp header.UDP - if flow.isV4() { - buf = c.buildV4Packet(payload, &h) - buf = append(buf, extraBytes...) - ip := header.IPv4(buf) - ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes))) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - proto = ipv4.ProtocolNumber - udp = ip.Payload() - } else { - buf = c.buildV6Packet(payload, &h) - buf = append(buf, extraBytes...) - ip := header.IPv6(buf) - ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes))) - proto = ipv6.ProtocolNumber - udp = ip.Payload() - } - - if diff := cmp.Diff(payload, udp.Payload()); diff != "" { - t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff) - } - - c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Try to receive the data. - v, _, err := c.ep.Read(nil) - if err != nil { - t.Fatalf("c.ep.Read(nil): %s", err) - } - - // Check the payload is read back without extra bytes. - if diff := cmp.Diff(buffer.View(payload), v); diff != "" { - t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff) - } - }) - } -} |