summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/transport/udp
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD77
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go746
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go91
-rw-r--r--pkg/tcpip/transport/udp/protocol.go73
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go625
5 files changed, 1612 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
new file mode 100644
index 000000000..ac34a932e
--- /dev/null
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -0,0 +1,77 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "udp_state",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "udp_packet_list.go",
+ ],
+ out = "udp_state.go",
+ imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ package = "udp",
+)
+
+go_template_instance(
+ name = "udp_packet_list",
+ out = "udp_packet_list.go",
+ package = "udp",
+ prefix = "udpPacket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*udpPacket",
+ },
+)
+
+go_library(
+ name = "udp",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "protocol.go",
+ "udp_packet_list.go",
+ "udp_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "udp_x_test",
+ size = "small",
+ srcs = ["udp_test.go"],
+ deps = [
+ ":udp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "udp_packet_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..80fa88c4c
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,746 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type udpPacket struct {
+ udpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is used as buffer for data when its length is large
+ // enough to store a VectorisedView.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+type endpoint struct {
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+}
+
+// NewConnectedEndpoint creates a new endpoint in the connected state using the
+// provided route.
+func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(stack, r.NetProto, waiterQueue)
+
+ // Register new endpoint so that packets are routed to it.
+ if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.id = id
+ ep.route = r.Clone()
+ ep.dstPort = id.RemotePort
+ ep.regNICID = r.NICID()
+
+ ep.state = stateConnected
+
+ return ep, nil
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared locked and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+
+ if route.IsResolutionRequired() {
+ // Promote lock to exclusive if using a shared route, given that it may need to
+ // change in Route.Resolve() call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ toCopy := *to
+ to = &toCopy
+ netProto, err := e.checkV4Mapped(to, true)
+ if err != nil {
+ return 0, err
+ }
+
+ // Find the enpoint.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
+ if err != nil {
+ return 0, err
+ }
+ defer r.Release()
+
+ route = &r
+ dstPort = to.Port
+ }
+
+ if route.IsResolutionRequired() {
+ waker := &sleep.Waker{}
+ if err := route.Resolve(waker); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Link address needs to be resolved. Resolution was triggered the background.
+ // Better luck next time.
+ //
+ // TODO: queue up the request and send after link address
+ // is resolved.
+ route.RemoveWaker(waker)
+ return 0, tcpip.ErrNoLinkAddress
+ }
+ return 0, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, err
+ }
+ sendUDP(route, v, e.id.LocalPort, dstPort)
+ return uintptr(len(v)), nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
+ return 0, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // TODO: Actually implement this.
+ switch v := opt.(type) {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+ }
+
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// sendUDP sends a UDP segment via the provided network endpoint and under the
+// provided identity.
+func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tcpip.Error {
+ // Allocate a buffer for the UDP header.
+ hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+ // Initialize the header.
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+ length := uint16(hdr.UsedLength()) + uint16(len(data))
+ udp.Encode(&header.UDPFields{
+ SrcPort: localPort,
+ DstPort: remotePort,
+ Length: length,
+ })
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+ if data != nil {
+ xsum = header.Checksum(data, xsum)
+ }
+
+ udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
+ }
+
+ return r.WritePacket(&hdr, data, ProtocolNumber)
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Port == 0 {
+ // We don't support connecting to port zero.
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateInitial:
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemotePort: addr.Port,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if e.netProto == header.IPv6ProtocolNumber && !e.v6only {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
+ }
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ // Remove the old registration.
+ if e.id.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.dstPort = addr.Port
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+ // The endpoint already has a local port, just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ // Unregister, the commit failed.
+ e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+ return err
+ }
+ }
+
+ e.id = id
+ e.regNICID = addr.NIC
+ e.effectiveNetProtos = netProtos
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr, commit)
+ if err != nil {
+ return err
+ }
+
+ e.bindNICID = addr.NIC
+ e.bindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ return
+ }
+
+ vv.TrimFront(header.UDPMinimumSize)
+
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &udpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ Port: hdr.SourcePort(),
+ },
+ }
+ pkt.data = vv.Clone(pkt.views[:])
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += vv.Size()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..41b98424a
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,91 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+ // We canoot save u.data directly as u.data.views may alias to u.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing u.views for data.views.
+ u.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after savercvBufSizeMax(), which would have
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, netProto)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, e.id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..fa30e7201
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,73 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the udp protocol name.
+ ProtocolName = "udp"
+
+ // ProtocolNumber is the udp protocol number.
+ ProtocolNumber = header.UDPProtocolNumber
+)
+
+type protocol struct{}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.UDP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
new file mode 100644
index 000000000..65c567952
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -0,0 +1,625 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp_test
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+ testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+
+ // defaultMTU is the MTU, in bytes, used throughout the tests, except
+ // where another value is explicitly used. It is chosen to match the MTU
+ // of loopback interfaces on linux systems.
+ defaultMTU = 65536
+)
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName})
+
+ id, linkEP := channel.New(256, mtu, "")
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ {
+ Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ }
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v4only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v4only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+}
+
+func (c *testContext) getV6Packet() []byte {
+ select {
+ case p := <-c.linkEP.C:
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ }
+ b := make([]byte, len(p.Header)+len(p.Payload))
+ copy(b, p.Header)
+ copy(b[len(p.Header):], p.Payload)
+
+ checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
+ return b
+
+ case <-time.After(2 * time.Second):
+ c.t.Fatalf("Packet wasn't written out")
+ }
+
+ return nil
+}
+
+func (c *testContext) getPacket() []byte {
+ select {
+ case p := <-c.linkEP.C:
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+ b := make([]byte, len(p.Header)+len(p.Payload))
+ copy(b, p.Header)
+ copy(b[len(p.Header):], p.Payload)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
+ return b
+
+ case <-time.After(2 * time.Second):
+ c.t.Fatalf("Packet wasn't written out")
+ }
+
+ return nil
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.Checksum([]byte(testV6Addr), 0)
+ xsum = header.Checksum([]byte(stackV6Addr), xsum)
+ xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
+
+ // Calculate the UDP checksum and set it.
+ length := uint16(header.UDPMinimumSize + len(payload))
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum, length))
+
+ // Inject packet.
+ var views [1]buffer.View
+ vv := buf.ToVectorisedView(views)
+ c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
+}
+
+func (c *testContext) sendPacket(payload []byte, h *headers) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testAddr,
+ DstAddr: stackAddr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.Checksum([]byte(testAddr), 0)
+ xsum = header.Checksum([]byte(stackAddr), xsum)
+ xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
+
+ // Calculate the UDP checksum and set it.
+ length := uint16(header.UDPMinimumSize + len(payload))
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum, length))
+
+ // Inject packet.
+ var views [1]buffer.View
+ vv := buf.ToVectorisedView(views)
+ c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func testV4Read(c *testContext) {
+ // Send a packet.
+ payload := newPayload()
+ c.sendPacket(payload, &headers{
+ srcPort: testPort,
+ dstPort: stackPort,
+ })
+
+ // Try to receive the data.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&we, waiter.EventIn)
+ defer c.wq.EventUnregister(&we)
+
+ var addr tcpip.FullAddress
+ v, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ // Wait for data to become available.
+ select {
+ case <-ch:
+ v, err = c.ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for data")
+ }
+ }
+
+ // Check the peer address.
+ if addr.Addr != testAddr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ }
+
+ // Check the payload.
+ if !bytes.Equal(payload, v) {
+ c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+ }
+}
+
+func TestV4ReadOnV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Read(c)
+}
+
+func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to v4 mapped wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Read(c)
+}
+
+func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to local address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Read(c)
+}
+
+func TestV6ReadOnV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Send a packet.
+ payload := newPayload()
+ c.sendV6Packet(payload, &headers{
+ srcPort: testPort,
+ dstPort: stackPort,
+ })
+
+ // Try to receive the data.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&we, waiter.EventIn)
+ defer c.wq.EventUnregister(&we)
+
+ var addr tcpip.FullAddress
+ v, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ // Wait for data to become available.
+ select {
+ case <-ch:
+ v, err = c.ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for data")
+ }
+ }
+
+ // Check the peer address.
+ if addr.Addr != testV6Addr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ }
+
+ // Check the payload.
+ if !bytes.Equal(payload, v) {
+ c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+ }
+}
+
+func TestV4ReadOnV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ // Create v4 UDP endpoint.
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Read(c)
+}
+
+func testDualWrite(c *testContext) uint16 {
+ // Write to V4 mapped address.
+ payload := buffer.View(newPayload())
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+ })
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ // Check that we received the packet.
+ b := c.getPacket()
+ udp := header.UDP(header.IPv4(b).Payload())
+ checker.IPv4(c.t, b,
+ checker.UDP(
+ checker.DstPort(testPort),
+ ),
+ )
+
+ port := udp.SourcePort()
+
+ // Check the payload.
+ if !bytes.Equal(payload, udp.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ }
+
+ // Write to v6 address.
+ payload = buffer.View(newPayload())
+ n, err = c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ })
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ // Check that we received the packet, and that the source port is the
+ // same as the one used in ipv4.
+ b = c.getV6Packet()
+ udp = header.UDP(header.IPv6(b).Payload())
+ checker.IPv6(c.t, b,
+ checker.UDP(
+ checker.DstPort(testPort),
+ checker.SrcPort(port),
+ ),
+ )
+
+ // Check the payload.
+ if !bytes.Equal(payload, udp.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ }
+
+ return port
+}
+
+func TestDualWriteUnbound(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ testDualWrite(c)
+}
+
+func TestDualWriteBoundToWildcard(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ p := testDualWrite(c)
+ if p != stackPort {
+ c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
+ }
+}
+
+func TestDualWriteConnectedToV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Connect to v6 address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ testDualWrite(c)
+}
+
+func TestDualWriteConnectedToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Connect to v4 mapped address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ testDualWrite(c)
+}
+
+func TestV4WriteOnV6Only(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(true)
+
+ // Write to V4 mapped address.
+ payload := buffer.View(newPayload())
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+ })
+ if err != tcpip.ErrNoRoute {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
+ }
+}
+
+func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Bind to v4 mapped address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Write to v6 address.
+ payload := buffer.View(newPayload())
+ _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ })
+ if err != tcpip.ErrNoRoute {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
+ }
+}
+
+func TestV6WriteOnConnected(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Connect to v6 address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Write without destination.
+ payload := buffer.View(newPayload())
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ // Check that we received the packet.
+ b := c.getV6Packet()
+ udp := header.UDP(header.IPv6(b).Payload())
+ checker.IPv6(c.t, b,
+ checker.UDP(
+ checker.DstPort(testPort),
+ ),
+ )
+
+ // Check the payload.
+ if !bytes.Equal(payload, udp.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ }
+}
+
+func TestV4WriteOnConnected(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ // Connect to v4 mapped address.
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Write without destination.
+ payload := buffer.View(newPayload())
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ // Check that we received the packet.
+ b := c.getPacket()
+ udp := header.UDP(header.IPv4(b).Payload())
+ checker.IPv4(c.t, b,
+ checker.UDP(
+ checker.DstPort(testPort),
+ ),
+ )
+
+ // Check the payload.
+ if !bytes.Equal(payload, udp.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ }
+}