summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go111
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go113
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go170
6 files changed, 336 insertions, 68 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..c1ca22b35 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "udp_packet_list",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ac5905772..0bec7e62d 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,13 +36,17 @@ type udpPacket struct {
views [8]buffer.View `state:"nosave"`
}
-type endpointState int
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
)
// endpoint represents a UDP endpoint. This struct serves as the interface
@@ -74,7 +77,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
sndBufSize int
id stack.TransportEndpointID
- state endpointState
+ state EndpointState
bindNICID tcpip.NICID
regNICID tcpip.NICID
route stack.Route `state:"manual"`
@@ -140,7 +143,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
- case stateBound, stateConnected:
+ case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
@@ -163,7 +166,7 @@ func (e *endpoint) Close() {
e.route.Release()
// Update the state.
- e.state = stateClosed
+ e.state = StateClosed
e.mu.Unlock()
@@ -211,11 +214,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
- case stateInitial:
- case stateConnected:
+ case StateInitial:
+ case StateConnected:
return false, nil
- case stateBound:
+ case StateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
@@ -232,7 +235,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return true, nil
}
@@ -273,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -322,7 +320,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
- if e.state != stateConnected {
+ if e.state != StateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
@@ -366,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
@@ -387,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -400,7 +407,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -566,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -576,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -726,7 +734,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
@@ -741,9 +749,13 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = stateBound
+ e.state = StateBound
} else {
- e.state = stateInitial
+ if e.id.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ }
+ e.state = StateInitial
}
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
@@ -772,8 +784,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicid := addr.NIC
var localPort uint16
switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
+ case StateInitial:
+ case StateBound, StateConnected:
localPort = e.id.LocalPort
if e.bindNICID == 0 {
break
@@ -801,7 +813,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == stateInitial {
+ if e.state == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -832,7 +844,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.regNICID = nicid
e.effectiveNetProtos = netProtos
- e.state = stateConnected
+ e.state = StateConnected
e.rcvMu.Lock()
e.rcvReady = true
@@ -854,7 +866,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -903,7 +915,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -946,7 +958,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
e.rcvMu.Lock()
e.rcvReady = true
@@ -989,7 +1001,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1069,10 +1081,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
-// State implements socket.Socket.State.
+// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- // TODO(b/112063468): Translate internal state to values returned by Linux.
- return 0
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
}
func isBroadcastOrMulticast(a tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 5cbb56120..be46e6d4e 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -77,7 +77,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return
}
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == stateConnected {
+ if e.state == StateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a874fc9fd..a9edc2c8d 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -84,7 +84,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.dstPort = r.id.RemotePort
ep.regNICID = r.route.NICID()
- ep.state = stateConnected
+ ep.state = StateConnected
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f76e7fbe1..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send ICMP error if the address is not a multicast/broadcast
+ // v4/v6 address or the source is not the unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 where it mentioned that at least 8
+ // bytes of the payload must be included. Today linux and other
+ // systems implement the] RFC1812 definition and not the original
+ // RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL())
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ }
return true
}
@@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 9da6edce2..2ec27be4d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -274,13 +274,17 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -461,7 +465,11 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
}
func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
@@ -1238,3 +1246,153 @@ func TestMulticastInterfaceOption(t *testing.T) {
})
}
}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true, will result in a payload large enough
+ // so that the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true will result in a payload large enough to
+ // create an IPv6 packet > header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}