summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go102
4 files changed, 29 insertions, 110 deletions
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index a6d4fcd59..98bdd29db 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
- // SrcPort is the "Source Port" field of a UDP packet.
+ // SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
- // DstPort is the "Destination Port" field of a UDP packet.
+ // DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
@@ -64,57 +64,52 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "Source Port" field of the UDP header.
+// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "Destination Port" field of the UDP header.
+// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "Length" field of the UDP header.
+// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
- return b[:b.Length()][UDPMinimumSize:]
+ return b[UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the UDP header.
+// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the UDP header.
+// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the UDP header.
+// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the UDP header.
+// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the UDP header.
+// SetLength sets the "length" field of the udp header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// PayloadLength returns the length of the payload following the UDP header.
-func (b UDP) PayloadLength() uint16 {
- return b.Length() - UDPMinimumSize
-}
-
-// CalculateChecksum calculates the checksum of the UDP packet, given the
+// CalculateChecksum calculates the checksum of the udp packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 153e8c950..7ebae63d8 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,6 +58,5 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 763d1d654..8e16c8435 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1259,6 +1259,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1267,10 +1268,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
- // packets at "Parse" instead of when handling a packet.
- pkt.Data.CapLength(int(hdr.PayloadLength()))
-
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1304,7 +1301,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 08980c298..6f89b6271 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,7 +22,6 @@ import (
"testing"
"time"
- "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -1915,31 +1914,27 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantPayloadLen := len(payload)
+ wantLen := len(payload)
if tc.largePayload {
// To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
- wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
}
// In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength))
- origDgram := header.UDP(payloadIPHeader.Payload())
- wantDgramLen := wantPayloadLen + header.UDPMinimumSize
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
- if got, want := len(origDgram), wantDgramLen; got != want {
- t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
}
- // Correct UDP length to access payload.
- origDgram.SetLength(uint16(wantDgramLen))
-
- if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
- t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -2014,23 +2009,20 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantPayloadLen := len(payload)
+ wantLen := len(payload)
if tc.largePayload {
- wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
}
- wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantDgramLen))
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want {
- t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
}
- // Correct UDP length to access payload.
- origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize))
- if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" {
- t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff)
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}
@@ -2543,67 +2535,3 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
})
}
}
-
-func TestReceiveShortLength(t *testing.T) {
- flows := []testFlow{unicastV4, unicastV6}
- for _, flow := range flows {
- t.Run(flow.String(), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
- }
-
- payload := newPayload()
- extraBytes := []byte{1, 2, 3, 4}
- h := flow.header4Tuple(incoming)
- var buf buffer.View
- var proto tcpip.NetworkProtocolNumber
-
- // Build packets with extra bytes not accounted for in the UDP length
- // field.
- var udp header.UDP
- if flow.isV4() {
- buf = c.buildV4Packet(payload, &h)
- buf = append(buf, extraBytes...)
- ip := header.IPv4(buf)
- ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- proto = ipv4.ProtocolNumber
- udp = ip.Payload()
- } else {
- buf = c.buildV6Packet(payload, &h)
- buf = append(buf, extraBytes...)
- ip := header.IPv6(buf)
- ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
- proto = ipv6.ProtocolNumber
- udp = ip.Payload()
- }
-
- if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
- t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
- }
-
- c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Try to receive the data.
- v, _, err := c.ep.Read(nil)
- if err != nil {
- t.Fatalf("c.ep.Read(nil): %s", err)
- }
-
- // Check the payload is read back without extra bytes.
- if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
- t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}