summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go102
4 files changed, 110 insertions, 29 deletions
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 98bdd29db..a6d4fcd59 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
- // SrcPort is the "source port" field of a UDP packet.
+ // SrcPort is the "Source Port" field of a UDP packet.
SrcPort uint16
- // DstPort is the "destination port" field of a UDP packet.
+ // DstPort is the "Destination Port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
@@ -64,52 +64,57 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
+// SourcePort returns the "Source Port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "destination port" field of the udp header.
+// DestinationPort returns the "Destination Port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "length" field of the udp header.
+// Length returns the "Length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
- return b[UDPMinimumSize:]
+ return b[:b.Length()][UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the udp header.
+// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the udp header.
+// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the udp header.
+// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the udp header.
+// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the udp header.
+// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the
+// PayloadLength returns the length of the payload following the UDP header.
+func (b UDP) PayloadLength() uint16 {
+ return b.Length() - UDPMinimumSize
+}
+
+// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5043e7aa5..ee1bb29f8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1267,7 +1267,6 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1276,6 +1275,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
+ // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
+ // packets at "Parse" instead of when handling a packet.
+ pkt.Data.CapLength(int(hdr.PayloadLength()))
+
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1309,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
},
}
packet.data = pkt.Data
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index e384f52dd..b0b9b2773 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -1827,27 +1828,31 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
+ payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+
+ if got, want := len(origDgram), wantDgramLen; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantDgramLen))
+
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -1922,20 +1927,23 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
+ wantPayloadLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
+ wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
+ t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
}
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ // Correct UDP length to access payload.
+ origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
+ if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
+ t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
}
})
}
@@ -2448,3 +2456,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
+
+func TestReceiveShortLength(t *testing.T) {
+ flows := []testFlow{unicastV4, unicastV6}
+ for _, flow := range flows {
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to wildcard.
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ payload := newPayload()
+ extraBytes := []byte{1, 2, 3, 4}
+ h := flow.header4Tuple(incoming)
+ var buf buffer.View
+ var proto tcpip.NetworkProtocolNumber
+
+ // Build packets with extra bytes not accounted for in the UDP length
+ // field.
+ var udp header.UDP
+ if flow.isV4() {
+ buf = c.buildV4Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv4(buf)
+ ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ proto = ipv4.ProtocolNumber
+ udp = ip.Payload()
+ } else {
+ buf = c.buildV6Packet(payload, &h)
+ buf = append(buf, extraBytes...)
+ ip := header.IPv6(buf)
+ ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
+ proto = ipv6.ProtocolNumber
+ udp = ip.Payload()
+ }
+
+ if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
+ t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
+ }
+
+ c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Try to receive the data.
+ v, _, err := c.ep.Read(nil)
+ if err != nil {
+ t.Fatalf("c.ep.Read(..): %s", err)
+ }
+
+ // Check the payload is read back without extra bytes.
+ if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
+ t.Errorf("c.ep.Read(..) mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}