diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/udp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 102 |
4 files changed, 110 insertions, 29 deletions
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..a6d4fcd59 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -36,10 +36,10 @@ const ( // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { - // SrcPort is the "source port" field of a UDP packet. + // SrcPort is the "Source Port" field of a UDP packet. SrcPort uint16 - // DstPort is the "destination port" field of a UDP packet. + // DstPort is the "Destination Port" field of a UDP packet. DstPort uint16 // Length is the "length" field of a UDP packet. @@ -64,52 +64,57 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "Source Port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "Destination Port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "Length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } // Payload returns the data contained in the UDP datagram. func (b UDP) Payload() []byte { - return b[UDPMinimumSize:] + return b[:b.Length()][UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. +// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// PayloadLength returns the length of the payload following the UDP header. +func (b UDP) PayloadLength() uint16 { + return b.Length() - UDPMinimumSize +} + +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 7ebae63d8..153e8c950 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -58,5 +58,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5043e7aa5..ee1bb29f8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1267,7 +1267,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1276,6 +1275,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } + // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap + // packets at "Parse" instead of when handling a packet. + pkt.Data.CapLength(int(hdr.PayloadLength())) + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() @@ -1309,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e384f52dd..b0b9b2773 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1827,27 +1828,31 @@ func TestV4UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { // To work out the data size we need to simulate what the sender would // have done. The wanted size is the total available minus the sum of // the headers in the UDP AND ICMP packets, given that we know the test // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength } // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - + payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + wantDgramLen := wantPayloadLen + header.UDPMinimumSize + + if got, want := len(origDgram), wantDgramLen; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantDgramLen)) + + if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { + t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) } }) } @@ -1922,20 +1927,23 @@ func TestV6UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv6(hdr.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize } + wantDgramLen := wantPayloadLen + header.UDPMinimumSize // In case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) + if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { + t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) } }) } @@ -2448,3 +2456,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } + +func TestReceiveShortLength(t *testing.T) { + flows := []testFlow{unicastV4, unicastV6} + for _, flow := range flows { + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err) + } + + payload := newPayload() + extraBytes := []byte{1, 2, 3, 4} + h := flow.header4Tuple(incoming) + var buf buffer.View + var proto tcpip.NetworkProtocolNumber + + // Build packets with extra bytes not accounted for in the UDP length + // field. + var udp header.UDP + if flow.isV4() { + buf = c.buildV4Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv4(buf) + ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes))) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + proto = ipv4.ProtocolNumber + udp = ip.Payload() + } else { + buf = c.buildV6Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv6(buf) + ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes))) + proto = ipv6.ProtocolNumber + udp = ip.Payload() + } + + if diff := cmp.Diff(payload, udp.Payload()); diff != "" { + t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff) + } + + c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Try to receive the data. + v, _, err := c.ep.Read(nil) + if err != nil { + t.Fatalf("c.ep.Read(..): %s", err) + } + + // Check the payload is read back without extra bytes. + if diff := cmp.Diff(buffer.View(payload), v); diff != "" { + t.Errorf("c.ep.Read(..) mismatch (-want +got):\n%s", diff) + } + }) + } +} |