summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-01 17:09:49 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-01 17:12:24 -0700
commitae3bd32011889fe59bb89946532dd7ee14973696 (patch)
treeaebfc760ab1ea6911d9376914e6551f1dfd6de27 /pkg
parent5032f4f57d9d46a3dfebb50523907724713e0001 (diff)
Extract network datagram endpoint common facilities
...from the UDP endpoint. Datagram-based transport endpoints (e.g. UDP, RAW IP) can share a lot of their write path due to the datagram-based nature of these endpoints. Extract the common facilities from UDP so they can be shared with other transport endpoints (in a later change). Test: UDP syscall tests. PiperOrigin-RevId: 394347774
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go10
-rw-r--r--pkg/tcpip/transport/BUILD13
-rw-r--r--pkg/tcpip/transport/datagram.go49
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD43
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go722
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go56
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go209
-rw-r--r--pkg/tcpip/transport/transport.go16
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go858
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go59
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go21
13 files changed, 1357 insertions, 702 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e347442e7..bf5ec4558 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 2f9462cee..8cf2f29e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -59,8 +59,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -2045,7 +2045,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
return syserr.ErrInvalidEndpointState
- } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial {
return syserr.ErrInvalidEndpointState
}
@@ -3331,10 +3331,10 @@ func (s *socketOpsCommon) State() uint32 {
}
case isUDPSocket(s.skType, s.protocol):
// UDP socket.
- switch udp.EndpointState(s.Endpoint.State()) {
- case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ switch transport.DatagramEndpointState(s.Endpoint.State()) {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed:
return linux.TCP_CLOSE
- case udp.StateConnected:
+ case transport.DatagramEndpointStateConnected:
return linux.TCP_ESTABLISHED
default:
return 0
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
new file mode 100644
index 000000000..af332ed91
--- /dev/null
+++ b/pkg/tcpip/transport/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "transport",
+ srcs = [
+ "datagram.go",
+ "transport.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go
new file mode 100644
index 000000000..dfce72c69
--- /dev/null
+++ b/pkg/tcpip/transport/datagram.go
@@ -0,0 +1,49 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// DatagramEndpointState is the state of a datagram-based endpoint.
+type DatagramEndpointState tcpip.EndpointState
+
+// The states a datagram-based endpoint may be in.
+const (
+ _ DatagramEndpointState = iota
+ DatagramEndpointStateInitial
+ DatagramEndpointStateBound
+ DatagramEndpointStateConnected
+ DatagramEndpointStateClosed
+)
+
+// String implements fmt.Stringer.
+func (s DatagramEndpointState) String() string {
+ switch s {
+ case DatagramEndpointStateInitial:
+ return "INITIAL"
+ case DatagramEndpointStateBound:
+ return "BOUND"
+ case DatagramEndpointStateConnected:
+ return "CONNECTED"
+ case DatagramEndpointStateClosed:
+ return "CLOSED"
+ default:
+ panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
new file mode 100644
index 000000000..d10e3f13a
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "network",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/transport/udp:__pkg__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ ],
+)
+
+go_test(
+ name = "network_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ deps = [
+ ":network",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
new file mode 100644
index 000000000..0dce60d89
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -0,0 +1,722 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network provides facilities to support tcpip.Endpoints that operate
+// at the network layer or above.
+package network
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
+// a peer.
+//
+// +stateify savable
+type Endpoint struct {
+ // The following fields must only be set once then never changed.
+ stack *stack.Stack `state:"manual"`
+ ops *tcpip.SocketOptions
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ // state holds a transport.DatagramBasedEndpointState.
+ //
+ // state must be read from/written to atomically.
+ state uint32
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ info stack.TransportEndpointInfo
+ // owner is the owner of transmitted packets.
+ owner tcpip.PacketOwner
+ writeShutdown bool
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ connectedRoute *stack.Route `state:"manual"`
+ multicastMemberships map[multicastMembership]struct{}
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ ttl uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastTTL uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastAddr tcpip.Address
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastNICID tcpip.NICID
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ sendTOS uint8
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+// Init initializes the endpoint.
+func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
+ if e.multicastMemberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ }
+
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ *e = Endpoint{
+ stack: s,
+ ops: ops,
+ netProto: netProto,
+ transProto: transProto,
+
+ state: uint32(transport.DatagramEndpointStateInitial),
+
+ info: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ effectiveNetProto: netProto,
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ }
+}
+
+// NetProto returns the network protocol the endpoint was initialized with.
+func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
+ return e.netProto
+}
+
+// setState sets the state of the endpoint.
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
+// State returns the state of the endpoint.
+func (e *Endpoint) State() transport.DatagramEndpointState {
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+}
+
+// Close cleans the endpoint's resources and leaves the endpoint in a closed
+// state.
+func (e *Endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ for mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ if e.connectedRoute != nil {
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+ }
+
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
+}
+
+// SetOwner sets the owner of transmitted packets.
+func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.owner = owner
+}
+
+func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 {
+ if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
+ return multicastTTL
+ }
+
+ if ttl == 0 {
+ return route.DefaultTTL()
+ }
+
+ return ttl
+}
+
+// WriteContext holds the context for a write.
+type WriteContext struct {
+ transProto tcpip.TransportProtocolNumber
+ route *stack.Route
+ ttl uint8
+ tos uint8
+ owner tcpip.PacketOwner
+}
+
+// Release releases held resources.
+func (c *WriteContext) Release() {
+ c.route.Release()
+ *c = WriteContext{}
+}
+
+// WritePacketInfo is the properties of a packet that may be written.
+type WritePacketInfo struct {
+ NetProto tcpip.NetworkProtocolNumber
+ LocalAddress, RemoteAddress tcpip.Address
+ MaxHeaderLength uint16
+ RequiresTXTransportChecksum bool
+}
+
+// PacketInfo returns the properties of a packet that will be written.
+func (c *WriteContext) PacketInfo() WritePacketInfo {
+ return WritePacketInfo{
+ NetProto: c.route.NetProto(),
+ LocalAddress: c.route.LocalAddress(),
+ RemoteAddress: c.route.RemoteAddress(),
+ MaxHeaderLength: c.route.MaxHeaderLength(),
+ RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
+ }
+}
+
+// WritePacket attempts to write the packet.
+func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+ pkt.Owner = c.owner
+
+ if headerIncluded {
+ return c.route.WriteHeaderIncludedPacket(pkt)
+ }
+
+ return c.route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: c.transProto,
+ TTL: c.ttl,
+ TOS: c.tos,
+ }, pkt)
+}
+
+// AcquireContextForWrite acquires a WriteContext.
+func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
+ }
+
+ if e.writeShutdown {
+ return WriteContext{}, &tcpip.ErrClosedForSend{}
+ }
+
+ route := e.connectedRoute
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return WriteContext{}, &tcpip.ErrDestinationRequired{}
+ }
+
+ route.Acquire()
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := opts.To.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
+ if e.info.BindNICID != 0 {
+ if nicID != 0 && nicID != e.info.BindNICID {
+ return WriteContext{}, &tcpip.ErrNoRoute{}
+ }
+
+ nicID = e.info.BindNICID
+ }
+
+ dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ if err != nil {
+ return WriteContext{}, err
+ }
+
+ route, _, err = e.connectRoute(nicID, dst, netProto)
+ if err != nil {
+ return WriteContext{}, err
+ }
+ }
+
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
+ route.Release()
+ return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
+ }
+
+ return WriteContext{
+ transProto: e.transProto,
+ route: route,
+ ttl: calculateTTL(route, e.ttl, e.multicastTTL),
+ tos: e.sendTOS,
+ owner: e.owner,
+ }, nil
+}
+
+// Disconnect disconnects the endpoint from its peer.
+func (e *Endpoint) Disconnect() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return
+ }
+
+ // Exclude ephemerally bound endpoints.
+ if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" {
+ e.info.ID = stack.TransportEndpointID{
+ LocalAddress: e.info.ID.LocalAddress,
+ }
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ } else {
+ e.info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+ }
+
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.info.ID.LocalAddress
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
+ if err != nil {
+ return nil, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Connect connects the endpoint to the address.
+func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ return nil
+ })
+}
+
+// ConnectAndThen connects the endpoint to the address and then calls the
+// provided function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the network protocol used to connect to the peer
+// and the source and destination addresses that will be used to send traffic to
+// the peer.
+func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicID := addr.NIC
+ switch e.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ if e.info.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.info.BindNICID {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ nicID = e.info.BindNICID
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ if err != nil {
+ return err
+ }
+
+ id := stack.TransportEndpointID{
+ LocalAddress: e.info.ID.LocalAddress,
+ RemoteAddress: r.RemoteAddress(),
+ }
+ if e.State() == transport.DatagramEndpointStateInitial {
+ id.LocalAddress = r.LocalAddress()
+ }
+
+ if err := f(r.NetProto(), e.info.ID, id); err != nil {
+ return err
+ }
+
+ if e.connectedRoute != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.connectedRoute.Release()
+ }
+ e.connectedRoute = r
+ e.info.ID = id
+ e.info.RegisterNICID = nicID
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateConnected)
+ return nil
+}
+
+// Shutdown shutsdown the endpoint.
+func (e *Endpoint) Shutdown() tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ e.writeShutdown = true
+ return nil
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
+
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
+}
+
+// Bind binds the endpoint to the address.
+func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
+ return nil
+ })
+}
+
+// BindAndThen binds the endpoint to the address and then calls the provided
+// function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the bound network protocol and address.
+func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.State() != transport.DatagramEndpointStateInitial {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
+ nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
+ if nicID == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if err := f(netProto, addr.Addr); err != nil {
+ return err
+ }
+
+ e.info.ID = stack.TransportEndpointID{
+ LocalAddress: addr.Addr,
+ }
+ e.info.BindNICID = nicID
+ e.info.RegisterNICID = nicID
+ e.info.BindAddr = addr.Addr
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ return nil
+}
+
+// GetLocalAddress returns the address that the endpoint is bound to.
+func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ addr := e.info.BindAddr
+ if e.State() == transport.DatagramEndpointStateConnected {
+ addr = e.connectedRoute.LocalAddress()
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.info.RegisterNICID,
+ Addr: addr,
+ }
+}
+
+// GetRemoteAddress returns the address that the endpoint is connected to.
+func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return tcpip.FullAddress{}, false
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.connectedRoute.RemoteAddress(),
+ NIC: e.info.RegisterNICID,
+ }, true
+}
+
+// SetSockOptInt sets the socket option.
+func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// GetSockOptInt returns the socket option.
+func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ default:
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+// SetSockOpt sets the socket option.
+func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4MappedLocked(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if e.info.BindNICID != 0 && e.info.BindNICID != nic {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case *tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return &tcpip.ErrPortInUse{}
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToInsert] = struct{}{}
+
+ case *tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ delete(e.multicastMemberships, memToRemove)
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt returns the socket option.
+func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
+ }
+ e.mu.Unlock()
+
+ default:
+ return &tcpip.ErrUnknownProtocolOption{}
+ }
+ return nil
+}
+
+// Info returns a copy of the endpoint info.
+func (e *Endpoint) Info() stack.TransportEndpointInfo {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
new file mode 100644
index 000000000..858007156
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -0,0 +1,56 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *Endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.stack = s
+
+ for m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err))
+ }
+ }
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound:
+ if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress))
+ }
+ }
+ case transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ multicastLoop := e.ops.GetMulticastLoop()
+ e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ if err != nil {
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
new file mode 100644
index 000000000..2c43eb66a
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -0,0 +1,209 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+func TestEndpointStateTransitions(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+ )
+
+ data := buffer.View([]byte{1, 2, 4, 5})
+ v4Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(ipv4NICAddr),
+ checker.DstAddr(ipv4RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ v6Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(ipv6NICAddr),
+ checker.DstAddr(ipv6RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ expectedMaxHeaderLength uint16
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLocalAddr tcpip.Address
+ bindAddr tcpip.Address
+ expectedBoundAddr tcpip.Address
+ remoteAddr tcpip.Address
+ expectedRemoteAddr tcpip.Address
+ checker func(*testing.T, buffer.View)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: header.IPv4AllSystems,
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: ipv4RemoteAddr,
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
+ expectedNetProto: ipv6.ProtocolNumber,
+ expectedLocalAddr: ipv6NICAddr,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ expectedBoundAddr: header.IPv6AllNodesMulticastAddress,
+ remoteAddr: ipv6RemoteAddr,
+ expectedRemoteAddr: ipv6RemoteAddr,
+ checker: v6Checker,
+ },
+ {
+ name: "IPv4-mapped-IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: testutil.MustParse6("::ffff:e000:0001"),
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: testutil.MustParse6("::ffff:0607:0809"),
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ })
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ if state := ep.State(); state != transport.DatagramEndpointStateInitial {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateBound {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); connected {
+ t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr)
+ }
+
+ connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
+ if err := ep.Connect(connectAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateConnected {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); !connected {
+ t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr)
+ } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" {
+ t.Errorf("remote address mismatch (-want +got):\n%s", diff)
+ }
+
+ ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("ep.AcquireContexForWrite({}): %s", err)
+ }
+ defer ctx.Release()
+ info := ctx.PacketInfo()
+ if diff := cmp.Diff(network.WritePacketInfo{
+ NetProto: test.expectedNetProto,
+ LocalAddress: test.expectedLocalAddr,
+ RemoteAddress: test.expectedRemoteAddr,
+ MaxHeaderLength: test.expectedMaxHeaderLength,
+ RequiresTXTransportChecksum: true,
+ }, info); diff != "" {
+ t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
+ }
+ if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(info.MaxHeaderLength),
+ Data: data.ToVectorisedView(),
+ }), false /* headerIncluded */); err != nil {
+ t.Fatalf("ctx.WritePacket(_, false): %s", err)
+ }
+ if pkt, ok := e.Read(); !ok {
+ t.Fatalf("expected packet to be read from link endpoint")
+ } else {
+ test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ }
+
+ ep.Close()
+ if state := ep.State(); state != transport.DatagramEndpointStateClosed {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go
new file mode 100644
index 000000000..4c2ae87f4
--- /dev/null
+++ b/pkg/tcpip/transport/transport.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport supports transport protocols.
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index cdc344ab7..5cc7a2886 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -35,6 +35,8 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 108580508..ac7ecb5f8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,8 @@
package udp
import (
+ "fmt"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -40,36 +42,6 @@ type udpPacket struct {
tos uint8
}
-// EndpointState represents the state of a UDP endpoint.
-type EndpointState tcpip.EndpointState
-
-// Endpoint states. Note that are represented in a netstack-specific manner and
-// may not be meaningful externally. Specifically, they need to be translated to
-// Linux's representation for these states if presented to userspace.
-const (
- _ EndpointState = iota
- StateInitial
- StateBound
- StateConnected
- StateClosed
-)
-
-// String implements fmt.Stringer.
-func (s EndpointState) String() string {
- switch s {
- case StateInitial:
- return "INITIAL"
- case StateBound:
- return "BOUND"
- case StateConnected:
- return "CONNECTING"
- case StateClosed:
- return "CLOSED"
- default:
- return "UNKNOWN"
- }
-}
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -79,7 +51,6 @@ func (s EndpointState) String() string {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
@@ -87,6 +58,10 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -96,37 +71,19 @@ type endpoint struct {
rcvBufSize int
rcvClosed bool
- // The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- // state must be read/set using the EndpointState()/setEndpointState()
- // methods.
- state uint32
- route *stack.Route `state:"manual"`
- dstPort uint16
- ttl uint8
- multicastTTL uint8
- multicastAddr tcpip.Address
- multicastNICID tcpip.NICID
- portFlags ports.Flags
-
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- sendTOS uint8
-
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
-
- // multicastMemberships that need to be remvoed when the endpoint is
- // closed. Protected by the mu mutex.
- multicastMemberships map[multicastMembership]struct{}
+ readShutdown bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -136,55 +93,25 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
-}
-// +stateify savable
-type multicastMembership struct {
- nicID tcpip.NICID
- multicastAddr tcpip.Address
+ localPort uint16
+ remotePort uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.UDPProtocolNumber,
- },
+ stack: s,
waiterQueue: waiterQueue,
- // RFC 1075 section 5.4 recommends a TTL of 1 for membership
- // requests.
- //
- // RFC 5135 4.2.1 appears to assume that IGMP messages have a
- // TTL of 1.
- //
- // RFC 5135 Appendix A defines TTL=1: A multicast source that
- // wants its traffic to not traverse a router (e.g., leave a
- // home network) may find it useful to send traffic with IP
- // TTL=1.
- //
- // Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastMemberships: make(map[multicastMembership]struct{}),
- state: uint32(StateInitial),
- uniqueID: s.UniqueID(),
+ uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
-// setEndpointState updates the state of the endpoint to state atomically. This
-// method is unexported as the only place we should update the state is in this
-// package but we allow the state to be read freely without holding e.mu.
-//
-// Precondition: e.mu must be held to call this method.
-func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
-// EndpointState() returns the current state of the endpoint.
-func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32(&e.state))
-}
-
// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -244,16 +157,22 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.EndpointState() {
- case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ e.mu.Unlock()
+ return
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -261,13 +180,10 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- for mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
- }
- e.multicastMemberships = nil
-
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -278,14 +194,9 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.setEndpointState(StateClosed)
-
+ e.net.Shutdown()
+ e.net.Close()
+ e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
@@ -359,19 +270,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
return res, nil
}
-// prepareForWrite prepares the endpoint for sending data. In particular, it
-// binds it if it's still in the initial state. To do so, it must first
+// prepareForWriteInner prepares the endpoint for sending data. In particular,
+// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.EndpointState() {
- case StateInitial:
- case StateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
- case StateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -386,7 +297,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -398,33 +309,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
return true, nil
}
-// connectRoute establishes a route to the specified interface or the
-// configured multicast interface if no interface is specified and the
-// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.ID.LocalAddress
- if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
- // A packet can only originate from a unicast address (i.e., an interface).
- localAddr = ""
- }
-
- if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicID == 0 {
- nicID = e.multicastNICID
- }
- if localAddr == "" && nicID == 0 {
- localAddr = e.multicastAddr
- }
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
- if err != nil {
- return nil, 0, err
- }
- return r, nicID, nil
-}
-
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
@@ -448,18 +332,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
+func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(opts.To)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
return udpPacketInfo{}, err
}
@@ -469,49 +348,27 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
}
}
- route := e.route
- dstPort := e.dstPort
+ dst, connected := e.net.GetRemoteAddress()
+ dst.Port = e.remotePort
if opts.To != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := opts.To.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return udpPacketInfo{}, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
- if err != nil {
- return udpPacketInfo{}, err
- }
-
- r, _, err := e.connectRoute(nicID, dst, netProto)
- if err != nil {
- return udpPacketInfo{}, err
- }
- defer r.Release()
-
- route = r
- dstPort = dst.Port
+ dst = *opts.To
+ } else if !connected {
+ return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
}
- if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return udpPacketInfo{}, err
}
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
@@ -520,50 +377,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
- route.NetProto(),
+ e.net.NetProto(),
header.UDPMaximumPacketSize,
- tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress(),
- Port: dstPort,
- },
+ dst,
v,
)
}
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
- ttl := e.ttl
- useDefaultTTL := ttl == 0
- if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
- ttl = e.multicastTTL
- // Multicast allows a 0 TTL.
- useDefaultTTL = false
- }
-
return udpPacketInfo{
- route: route,
- data: buffer.View(v),
- localPort: e.ID.LocalPort,
- remotePort: dstPort,
- ttl: ttl,
- useDefaultTTL: useDefaultTTL,
- tos: e.sendTOS,
- owner: e.owner,
- noChecksum: e.SocketOptions().GetNoChecksum(),
+ ctx: ctx,
+ data: v,
+ localPort: e.localPort,
+ remotePort: dst.Port,
}, nil
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
// deadlock where the ICMP response handling ends up acquiring this endpoint's
@@ -574,15 +406,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- u, err := e.buildUDPPacketInfo(p, opts)
- if err != nil {
+
+ if err := e.LastError(); err != nil {
return 0, err
}
- n, err := u.send()
+
+ udpInfo, err := e.prepareForWrite(p, opts)
if err != nil {
return 0, err
}
- return int64(n), nil
+ defer udpInfo.ctx.Release()
+
+ pktInfo := udpInfo.ctx.PacketInfo()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
+ Data: udpInfo.data.ToVectorisedView(),
+ })
+
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
+
+ length := uint16(pkt.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: udpInfo.localPort,
+ DstPort: udpInfo.remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if pktInfo.RequiresTXTransportChecksum &&
+ (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
+ udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
+ header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
+ pkt.Data().AsRange().Checksum(),
+ )))
+ }
+ if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ e.stack.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
+ }
+
+ // Track count of packets sent.
+ e.stack.Stats().UDP.PacketsSent.Increment()
+ return int64(len(udpInfo.data)), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -601,36 +471,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.MTUDiscoverOption:
- // Return not supported if the value is not disabling path
- // MTU discovery.
- if v != tcpip.PMTUDiscoveryDont {
- return &tcpip.ErrNotSupported{}
- }
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- e.multicastTTL = uint8(v)
- e.mu.Unlock()
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- }
-
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
@@ -642,145 +483,12 @@ func (e *endpoint) HasNIC(id int32) bool {
// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
-
- fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
- if err != nil {
- return err
- }
- nic := v.NIC
- addr := fa.Addr
-
- if nic == 0 && addr == "" {
- e.multicastAddr = ""
- e.multicastNICID = 0
- break
- }
-
- if nic != 0 {
- if !e.stack.CheckNIC(nic) {
- return &tcpip.ErrBadLocalAddress{}
- }
- } else {
- nic = e.stack.CheckLocalAddress(0, netProto, addr)
- if nic == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
- }
-
- if e.BindNICID != 0 && e.BindNICID != nic {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- e.multicastNICID = nic
- e.multicastAddr = addr
-
- case *tcpip.AddMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
-
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToInsert]; ok {
- return &tcpip.ErrPortInUse{}
- }
-
- if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- e.multicastMemberships[memToInsert] = struct{}{}
-
- case *tcpip.RemoveMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return &tcpip.ErrBadLocalAddress{}
- }
-
- if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- delete(e.multicastMemberships, memToRemove)
-
- case *tcpip.SocketDetachFilterOption:
- return nil
- }
- return nil
+ return e.net.SetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
- case tcpip.IPv4TOSOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.MTUDiscoverOption:
- // The only supported setting is path MTU discovery disabled.
- return tcpip.PMTUDiscoveryDont, nil
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- v := int(e.multicastTTL)
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -791,108 +499,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.mu.Lock()
- v := int(e.ttl)
- e.mu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- *o = tcpip.MulticastInterfaceOption{
- NIC: e.multicastNICID,
- InterfaceAddr: e.multicastAddr,
- }
- e.mu.Unlock()
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
- return nil
+ return e.net.GetSockOpt(opt)
}
-// udpPacketInfo contains all information required to send a UDP packet.
-//
-// This should be used as a value-only type, which exists in order to simplify
-// return value syntax. It should not be exported or extended.
+// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
- route *stack.Route
- data buffer.View
- localPort uint16
- remotePort uint16
- ttl uint8
- useDefaultTTL bool
- tos uint8
- owner tcpip.PacketOwner
- noChecksum bool
-}
-
-// send sends the given packet.
-func (u *udpPacketInfo) send() (int, tcpip.Error) {
- vv := u.data.ToVectorisedView()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
- Data: vv,
- })
- pkt.Owner = u.owner
-
- // Initialize the UDP header.
- udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- pkt.TransportProtocolNumber = ProtocolNumber
-
- length := uint16(pkt.Size())
- udp.Encode(&header.UDPFields{
- SrcPort: u.localPort,
- DstPort: u.remotePort,
- Length: length,
- })
-
- // Set the checksum field unless TX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value indicates the
- // transmitter skipped the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if u.route.RequiresTXTransportChecksum() &&
- (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
- xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udp.SetChecksum(^udp.CalculateChecksum(xsum))
- }
-
- if u.useDefaultTTL {
- u.ttl = u.route.DefaultTTL()
- }
- if err := u.route.WritePacket(stack.NetworkHeaderParams{
- Protocol: ProtocolNumber,
- TTL: u.ttl,
- TOS: u.tos,
- }, pkt); err != nil {
- u.route.Stats().UDP.PacketSendErrors.Increment()
- return 0, err
- }
-
- // Track count of packets sent.
- u.route.Stats().UDP.PacketsSent.Increment()
- return len(u.data), nil
-}
-
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
+ ctx network.WriteContext
+ data buffer.View
+ localPort uint16
+ remotePort uint16
}
// Disconnect implements tcpip.Endpoint.
@@ -900,7 +522,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.EndpointState() != StateConnected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return nil
}
var (
@@ -913,26 +535,28 @@ func (e *endpoint) Disconnect() tcpip.Error {
boundPortFlags := e.boundPortFlags
// Exclude ephemerally bound endpoints.
- if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ if info.BindNICID != 0 || info.ID.LocalAddress == "" {
var err tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.ID.LocalPort,
- LocalAddress: e.ID.LocalAddress,
+ LocalPort: info.ID.LocalPort,
+ LocalAddress: info.ID.LocalAddress,
}
id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
- if e.ID.LocalPort != 0 {
+ if info.ID.LocalPort != 0 {
// Release the ephemeral port.
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: info.ID.LocalAddress,
+ Port: info.ID.LocalPort,
Flags: boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -940,15 +564,14 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
- e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
- e.ID = id
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = btd
- e.route.Release()
- e.route = nil
- e.dstPort = 0
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+
+ e.net.Disconnect()
return nil
}
@@ -958,88 +581,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- var localPort uint16
- switch e.EndpointState() {
- case StateInitial:
- case StateBound, StateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
-
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.localPort
+ nextID.RemotePort = addr.Port
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: e.ID.LocalAddress,
- LocalPort: localPort,
- RemotePort: addr.Port,
- RemoteAddress: r.RemoteAddress(),
- }
+ oldPortFlags := e.boundPortFlags
- if e.EndpointState() == StateInitial {
- id.LocalAddress = r.LocalAddress()
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv4ProtocolNumber,
- header.IPv6ProtocolNumber,
+ nextID, btd, err := e.registerWithStack(netProtos, nextID)
+ if err != nil {
+ return err
}
- }
- oldPortFlags := e.boundPortFlags
+ // Remove the old registration.
+ if e.localPort != 0 {
+ previousID.LocalPort = e.localPort
+ previousID.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
+ }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = nextID.LocalPort
+ e.remotePort = nextID.RemotePort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- // Remove the old registration.
- if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
- }
-
- e.ID = id
- e.boundBindToDevice = btd
- if e.route != nil {
- // If the endpoint was already connected then make sure we release the
- // previous route.
- e.route.Release()
- }
- e.route = r
- e.dstPort = addr.Port
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- e.setEndpointState(StateConnected)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1054,15 +637,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // A socket in the bound state can still receive multicast messages,
- // so we need to notify waiters on shutdown.
- if state := e.EndpointState(); state != StateBound && state != StateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- e.shutdownFlags |= flags
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
+ }
if flags&tcpip.ShutdownRead != 0 {
+ e.readShutdown = true
+
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
@@ -1088,7 +679,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if e.ID.LocalPort == 0 {
+ if e.localPort == 0 {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
@@ -1126,56 +717,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
+ if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
}
- }
- nicID := addr.NIC
- if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
- // A local unicast address was specified, verify that it's valid.
- nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicID == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: boundAddr,
+ }
+ id, btd, err := e.registerWithStack(netProtos, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = id.LocalPort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.boundBindToDevice = btd
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- // Mark endpoint as bound.
- e.setEndpointState(StateBound)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1190,9 +768,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
return err
}
- // Save the effective NICID generated by bindLocked.
- e.BindNICID = e.RegisterNICID
-
return nil
}
@@ -1201,16 +776,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.ID.LocalAddress
- if e.EndpointState() == StateConnected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.localPort
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -1218,15 +786,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected || e.dstPort == 0 {
+ addr, connected := e.net.GetRemoteAddress()
+ if !connected || e.remotePort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ addr.Port = e.remotePort
+ return addr, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -1376,19 +942,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
payload = udp.Payload()
}
+ id := e.net.Info().ID
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Payload: payload,
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: id.RemoteAddress,
+ Port: e.remotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: e.localPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -1403,7 +970,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// TODO(gvisor.dev/issues/5270): Handle all transport errors.
switch transErr.Kind() {
case stack.DestinationPortUnreachableTransportError:
- if e.EndpointState() == StateConnected {
+ if e.net.State() == transport.DatagramEndpointStateConnected {
e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
}
}
@@ -1411,16 +978,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
- return uint32(e.EndpointState())
+ return uint32(e.net.State())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
- return &ret
+ defer e.mu.RUnlock()
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ return &info
}
// Stats returns a pointer to the endpoint stats.
@@ -1431,13 +999,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
-func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
-}
-
// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// SocketOptions implements tcpip.Endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 1f638c3f6..20c45ab87 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,12 +15,13 @@
package udp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -66,50 +67,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.mu.Lock()
defer e.mu.Unlock()
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- for m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- state := e.EndpointState()
- if state != StateBound && state != StateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err tcpip.Error
- if state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ var err tcpip.Error
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
- // A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.ID
- e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
- if err != nil {
- panic(err)
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 7c357cb09..7238fc019 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
netHdr := r.pkt.Network()
- route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
+ if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
+ return nil, err
+ }
+
+ if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
return nil, err
}
- ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
- route.Release()
return nil, err
}
- ep.ID = r.id
- ep.route = route
- ep.dstPort = r.id.RemotePort
+ ep.localPort = r.id.LocalPort
+ ep.remotePort = r.id.RemotePort
ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
- ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = uint32(StateConnected)
-
ep.rcvMu.Lock()
ep.rcvReady = true
ep.rcvMu.Unlock()