diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-09-01 17:09:49 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-01 17:12:24 -0700 |
commit | ae3bd32011889fe59bb89946532dd7ee14973696 (patch) | |
tree | aebfc760ab1ea6911d9376914e6551f1dfd6de27 /pkg | |
parent | 5032f4f57d9d46a3dfebb50523907724713e0001 (diff) |
Extract network datagram endpoint common facilities
...from the UDP endpoint.
Datagram-based transport endpoints (e.g. UDP, RAW IP) can share a lot
of their write path due to the datagram-based nature of these endpoints.
Extract the common facilities from UDP so they can be shared with other
transport endpoints (in a later change).
Test: UDP syscall tests.
PiperOrigin-RevId: 394347774
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/netstack/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/BUILD | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/datagram.go | 49 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/BUILD | 43 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 722 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint_state.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint_test.go | 209 | ||||
-rw-r--r-- | pkg/tcpip/transport/transport.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 858 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 59 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 21 |
13 files changed, 1357 insertions, 702 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e347442e7..bf5ec4558 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/usermem", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2f9462cee..8cf2f29e4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -59,8 +59,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -2045,7 +2045,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { return syserr.ErrInvalidEndpointState - } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial { return syserr.ErrInvalidEndpointState } @@ -3331,10 +3331,10 @@ func (s *socketOpsCommon) State() uint32 { } case isUDPSocket(s.skType, s.protocol): // UDP socket. - switch udp.EndpointState(s.Endpoint.State()) { - case udp.StateInitial, udp.StateBound, udp.StateClosed: + switch transport.DatagramEndpointState(s.Endpoint.State()) { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed: return linux.TCP_CLOSE - case udp.StateConnected: + case transport.DatagramEndpointStateConnected: return linux.TCP_ESTABLISHED default: return 0 diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD new file mode 100644 index 000000000..af332ed91 --- /dev/null +++ b/pkg/tcpip/transport/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "transport", + srcs = [ + "datagram.go", + "transport.go", + ], + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go new file mode 100644 index 000000000..dfce72c69 --- /dev/null +++ b/pkg/tcpip/transport/datagram.go @@ -0,0 +1,49 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// DatagramEndpointState is the state of a datagram-based endpoint. +type DatagramEndpointState tcpip.EndpointState + +// The states a datagram-based endpoint may be in. +const ( + _ DatagramEndpointState = iota + DatagramEndpointStateInitial + DatagramEndpointStateBound + DatagramEndpointStateConnected + DatagramEndpointStateClosed +) + +// String implements fmt.Stringer. +func (s DatagramEndpointState) String() string { + switch s { + case DatagramEndpointStateInitial: + return "INITIAL" + case DatagramEndpointStateBound: + return "BOUND" + case DatagramEndpointStateConnected: + return "CONNECTED" + case DatagramEndpointStateClosed: + return "CLOSED" + default: + panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s)) + } +} diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD new file mode 100644 index 000000000..d10e3f13a --- /dev/null +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + srcs = [ + "endpoint.go", + "endpoint_state.go", + ], + visibility = [ + "//pkg/tcpip/transport/udp:__pkg__", + ], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + ], +) + +go_test( + name = "network_test", + size = "small", + srcs = ["endpoint_test.go"], + deps = [ + ":network", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go new file mode 100644 index 000000000..0dce60d89 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -0,0 +1,722 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network provides facilities to support tcpip.Endpoints that operate +// at the network layer or above. +package network + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Endpoint is a datagram-based endpoint. It only supports sending datagrams to +// a peer. +// +// +stateify savable +type Endpoint struct { + // The following fields must only be set once then never changed. + stack *stack.Stack `state:"manual"` + ops *tcpip.SocketOptions + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + + // state holds a transport.DatagramBasedEndpointState. + // + // state must be read from/written to atomically. + state uint32 + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + info stack.TransportEndpointInfo + // owner is the owner of transmitted packets. + owner tcpip.PacketOwner + writeShutdown bool + effectiveNetProto tcpip.NetworkProtocolNumber + connectedRoute *stack.Route `state:"manual"` + multicastMemberships map[multicastMembership]struct{} + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + ttl uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastTTL uint8 + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastAddr tcpip.Address + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + multicastNICID tcpip.NICID + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. + sendTOS uint8 +} + +// +stateify savable +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +// Init initializes the endpoint. +func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) { + if e.multicastMemberships != nil { + panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships)) + } + + switch netProto { + case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber: + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + + *e = Endpoint{ + stack: s, + ops: ops, + netProto: netProto, + transProto: transProto, + + state: uint32(transport.DatagramEndpointStateInitial), + + info: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, + effectiveNetProto: netProto, + // Linux defaults to TTL=1. + multicastTTL: 1, + multicastMemberships: make(map[multicastMembership]struct{}), + } +} + +// NetProto returns the network protocol the endpoint was initialized with. +func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { + return e.netProto +} + +// setState sets the state of the endpoint. +func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { + atomic.StoreUint32(&e.state, uint32(state)) +} + +// State returns the state of the endpoint. +func (e *Endpoint) State() transport.DatagramEndpointState { + return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) +} + +// Close cleans the endpoint's resources and leaves the endpoint in a closed +// state. +func (e *Endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() == transport.DatagramEndpointStateClosed { + return + } + + for mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + if e.connectedRoute != nil { + e.connectedRoute.Release() + e.connectedRoute = nil + } + + e.setEndpointState(transport.DatagramEndpointStateClosed) +} + +// SetOwner sets the owner of transmitted packets. +func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) { + e.mu.Lock() + defer e.mu.Unlock() + e.owner = owner +} + +func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 { + if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { + return multicastTTL + } + + if ttl == 0 { + return route.DefaultTTL() + } + + return ttl +} + +// WriteContext holds the context for a write. +type WriteContext struct { + transProto tcpip.TransportProtocolNumber + route *stack.Route + ttl uint8 + tos uint8 + owner tcpip.PacketOwner +} + +// Release releases held resources. +func (c *WriteContext) Release() { + c.route.Release() + *c = WriteContext{} +} + +// WritePacketInfo is the properties of a packet that may be written. +type WritePacketInfo struct { + NetProto tcpip.NetworkProtocolNumber + LocalAddress, RemoteAddress tcpip.Address + MaxHeaderLength uint16 + RequiresTXTransportChecksum bool +} + +// PacketInfo returns the properties of a packet that will be written. +func (c *WriteContext) PacketInfo() WritePacketInfo { + return WritePacketInfo{ + NetProto: c.route.NetProto(), + LocalAddress: c.route.LocalAddress(), + RemoteAddress: c.route.RemoteAddress(), + MaxHeaderLength: c.route.MaxHeaderLength(), + RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(), + } +} + +// WritePacket attempts to write the packet. +func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { + pkt.Owner = c.owner + + if headerIncluded { + return c.route.WriteHeaderIncludedPacket(pkt) + } + + return c.route.WritePacket(stack.NetworkHeaderParams{ + Protocol: c.transProto, + TTL: c.ttl, + TOS: c.tos, + }, pkt) +} + +// AcquireContextForWrite acquires a WriteContext. +func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return WriteContext{}, &tcpip.ErrInvalidOptionValue{} + } + + if e.State() == transport.DatagramEndpointStateClosed { + return WriteContext{}, &tcpip.ErrInvalidEndpointState{} + } + + if e.writeShutdown { + return WriteContext{}, &tcpip.ErrClosedForSend{} + } + + route := e.connectedRoute + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if e.State() != transport.DatagramEndpointStateConnected { + return WriteContext{}, &tcpip.ErrDestinationRequired{} + } + + route.Acquire() + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicID := opts.To.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } + if e.info.BindNICID != 0 { + if nicID != 0 && nicID != e.info.BindNICID { + return WriteContext{}, &tcpip.ErrNoRoute{} + } + + nicID = e.info.BindNICID + } + + dst, netProto, err := e.checkV4MappedLocked(*opts.To) + if err != nil { + return WriteContext{}, err + } + + route, _, err = e.connectRoute(nicID, dst, netProto) + if err != nil { + return WriteContext{}, err + } + } + + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { + route.Release() + return WriteContext{}, &tcpip.ErrBroadcastDisabled{} + } + + return WriteContext{ + transProto: e.transProto, + route: route, + ttl: calculateTTL(route, e.ttl, e.multicastTTL), + tos: e.sendTOS, + owner: e.owner, + }, nil +} + +// Disconnect disconnects the endpoint from its peer. +func (e *Endpoint) Disconnect() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return + } + + // Exclude ephemerally bound endpoints. + if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" { + e.info.ID = stack.TransportEndpointID{ + LocalAddress: e.info.ID.LocalAddress, + } + e.setEndpointState(transport.DatagramEndpointStateBound) + } else { + e.info.ID = stack.TransportEndpointID{} + e.setEndpointState(transport.DatagramEndpointStateInitial) + } + + e.connectedRoute.Release() + e.connectedRoute = nil +} + +// connectRoute establishes a route to the specified interface or the +// configured multicast interface if no interface is specified and the +// specified address is a multicast address. +func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { + localAddr := e.info.ID.LocalAddress + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { + if nicID == 0 { + nicID = e.multicastNICID + } + if localAddr == "" && nicID == 0 { + localAddr = e.multicastAddr + } + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) + if err != nil { + return nil, 0, err + } + return r, nicID, nil +} + +// Connect connects the endpoint to the address. +func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + return nil + }) +} + +// ConnectAndThen connects the endpoint to the address and then calls the +// provided function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the network protocol used to connect to the peer +// and the source and destination addresses that will be used to send traffic to +// the peer. +func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + nicID := addr.NIC + switch e.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + if e.info.BindNICID == 0 { + break + } + + if nicID != 0 && nicID != e.info.BindNICID { + return &tcpip.ErrInvalidEndpointState{} + } + + nicID = e.info.BindNICID + default: + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + r, nicID, err := e.connectRoute(nicID, addr, netProto) + if err != nil { + return err + } + + id := stack.TransportEndpointID{ + LocalAddress: e.info.ID.LocalAddress, + RemoteAddress: r.RemoteAddress(), + } + if e.State() == transport.DatagramEndpointStateInitial { + id.LocalAddress = r.LocalAddress() + } + + if err := f(r.NetProto(), e.info.ID, id); err != nil { + return err + } + + if e.connectedRoute != nil { + // If the endpoint was previously connected then release any previous route. + e.connectedRoute.Release() + } + e.connectedRoute = r + e.info.ID = id + e.info.RegisterNICID = nicID + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateConnected) + return nil +} + +// Shutdown shutsdown the endpoint. +func (e *Endpoint) Shutdown() tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + e.writeShutdown = true + return nil + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} + +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) + if err != nil { + return tcpip.FullAddress{}, 0, err + } + return unwrapped, netProto, nil +} + +func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) +} + +// Bind binds the endpoint to the address. +func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { + return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error { + return nil + }) +} + +// BindAndThen binds the endpoint to the address and then calls the provided +// function. +// +// If the function returns an error, the endpoint's state does not change. The +// function will be called with the bound network protocol and address. +func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error { + addr.Port = 0 + + e.mu.Lock() + defer e.mu.Unlock() + + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.State() != transport.DatagramEndpointStateInitial { + return &tcpip.ErrInvalidEndpointState{} + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + nicID := addr.NIC + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { + nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr) + if nicID == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if err := f(netProto, addr.Addr); err != nil { + return err + } + + e.info.ID = stack.TransportEndpointID{ + LocalAddress: addr.Addr, + } + e.info.BindNICID = nicID + e.info.RegisterNICID = nicID + e.info.BindAddr = addr.Addr + e.effectiveNetProto = netProto + e.setEndpointState(transport.DatagramEndpointStateBound) + return nil +} + +// GetLocalAddress returns the address that the endpoint is bound to. +func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { + e.mu.RLock() + defer e.mu.RUnlock() + + addr := e.info.BindAddr + if e.State() == transport.DatagramEndpointStateConnected { + addr = e.connectedRoute.LocalAddress() + } + + return tcpip.FullAddress{ + NIC: e.info.RegisterNICID, + Addr: addr, + } +} + +// GetRemoteAddress returns the address that the endpoint is connected to. +func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.State() != transport.DatagramEndpointStateConnected { + return tcpip.FullAddress{}, false + } + + return tcpip.FullAddress{ + Addr: e.connectedRoute.RemoteAddress(), + NIC: e.info.RegisterNICID, + }, true +} + +// SetSockOptInt sets the socket option. +func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return &tcpip.ErrNotSupported{} + } + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + } + + return nil +} + +// GetSockOptInt returns the socket option. +func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + switch opt { + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + default: + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +// SetSockOpt sets the socket option. +func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + switch v := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + defer e.mu.Unlock() + + fa := tcpip.FullAddress{Addr: v.InterfaceAddr} + fa, netProto, err := e.checkV4MappedLocked(fa) + if err != nil { + return err + } + nic := v.NIC + addr := fa.Addr + + if nic == 0 && addr == "" { + e.multicastAddr = "" + e.multicastNICID = 0 + break + } + + if nic != 0 { + if !e.stack.CheckNIC(nic) { + return &tcpip.ErrBadLocalAddress{} + } + } else { + nic = e.stack.CheckLocalAddress(0, netProto, addr) + if nic == 0 { + return &tcpip.ErrBadLocalAddress{} + } + } + + if e.info.BindNICID != 0 && e.info.BindNICID != nic { + return &tcpip.ErrInvalidEndpointState{} + } + + e.multicastNICID = nic + e.multicastAddr = addr + + case *tcpip.AddMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToInsert]; ok { + return &tcpip.ErrPortInUse{} + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToInsert] = struct{}{} + + case *tcpip.RemoveMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return &tcpip.ErrInvalidOptionValue{} + } + + nicID := v.NIC + if v.InterfaceAddr.Unspecified() { + if nicID == 0 { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return &tcpip.ErrUnknownDevice{} + } + + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + if _, ok := e.multicastMemberships[memToRemove]; !ok { + return &tcpip.ErrBadLocalAddress{} + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + delete(e.multicastMemberships, memToRemove) + + case *tcpip.SocketDetachFilterOption: + return nil + } + return nil +} + +// GetSockOpt returns the socket option. +func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + switch o := opt.(type) { + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, + } + e.mu.Unlock() + + default: + return &tcpip.ErrUnknownProtocolOption{} + } + return nil +} + +// Info returns a copy of the endpoint info. +func (e *Endpoint) Info() stack.TransportEndpointInfo { + e.mu.RLock() + defer e.mu.RUnlock() + return e.info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go new file mode 100644 index 000000000..858007156 --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -0,0 +1,56 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" +) + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *Endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + + e.stack = s + + for m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err)) + } + } + + switch state := e.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound: + if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + } + } + case transport.DatagramEndpointStateConnected: + var err tcpip.Error + multicastLoop := e.ops.GetMulticastLoop() + e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + if err != nil { + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + } + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go new file mode 100644 index 000000000..2c43eb66a --- /dev/null +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -0,0 +1,209 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +func TestEndpointStateTransitions(t *testing.T) { + const ( + nicID = 1 + ) + + var ( + ipv4NICAddr = testutil.MustParse4("1.2.3.4") + ipv6NICAddr = testutil.MustParse6("a::1") + ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") + ipv6RemoteAddr = testutil.MustParse6("b::1") + ) + + data := buffer.View([]byte{1, 2, 4, 5}) + v4Checker := func(t *testing.T, b buffer.View) { + checker.IPv4(t, b, + checker.SrcAddr(ipv4NICAddr), + checker.DstAddr(ipv4RemoteAddr), + checker.IPPayload(data), + ) + } + + v6Checker := func(t *testing.T, b buffer.View) { + checker.IPv6(t, b, + checker.SrcAddr(ipv6NICAddr), + checker.DstAddr(ipv6RemoteAddr), + checker.IPPayload(data), + ) + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + expectedMaxHeaderLength uint16 + expectedNetProto tcpip.NetworkProtocolNumber + expectedLocalAddr tcpip.Address + bindAddr tcpip.Address + expectedBoundAddr tcpip.Address + remoteAddr tcpip.Address + expectedRemoteAddr tcpip.Address + checker func(*testing.T, buffer.View) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: header.IPv4AllSystems, + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: ipv4RemoteAddr, + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv6FixedHeaderSize, + expectedNetProto: ipv6.ProtocolNumber, + expectedLocalAddr: ipv6NICAddr, + bindAddr: header.IPv6AllNodesMulticastAddress, + expectedBoundAddr: header.IPv6AllNodesMulticastAddress, + remoteAddr: ipv6RemoteAddr, + expectedRemoteAddr: ipv6RemoteAddr, + checker: v6Checker, + }, + { + name: "IPv4-mapped-IPv6", + netProto: ipv6.ProtocolNumber, + expectedMaxHeaderLength: header.IPv4MaximumHeaderSize, + expectedNetProto: ipv4.ProtocolNumber, + expectedLocalAddr: ipv4NICAddr, + bindAddr: testutil.MustParse6("::ffff:e000:0001"), + expectedBoundAddr: header.IPv4AllSystems, + remoteAddr: testutil.MustParse6("::ffff:0607:0809"), + expectedRemoteAddr: ipv4RemoteAddr, + checker: v4Checker, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID}, + }) + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + if state := ep.State(); state != transport.DatagramEndpointStateInitial { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateBound { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); connected { + t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr) + } + + connectAddr := tcpip.FullAddress{Addr: test.remoteAddr} + if err := ep.Connect(connectAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", connectAddr, err) + } + if state := ep.State(); state != transport.DatagramEndpointStateConnected { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected) + } + if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" { + t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff) + } + if addr, connected := ep.GetRemoteAddress(); !connected { + t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr) + } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" { + t.Errorf("remote address mismatch (-want +got):\n%s", diff) + } + + ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("ep.AcquireContexForWrite({}): %s", err) + } + defer ctx.Release() + info := ctx.PacketInfo() + if diff := cmp.Diff(network.WritePacketInfo{ + NetProto: test.expectedNetProto, + LocalAddress: test.expectedLocalAddr, + RemoteAddress: test.expectedRemoteAddr, + MaxHeaderLength: test.expectedMaxHeaderLength, + RequiresTXTransportChecksum: true, + }, info); diff != "" { + t.Errorf("write packet info mismatch (-want +got):\n%s", diff) + } + if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(info.MaxHeaderLength), + Data: data.ToVectorisedView(), + }), false /* headerIncluded */); err != nil { + t.Fatalf("ctx.WritePacket(_, false): %s", err) + } + if pkt, ok := e.Read(); !ok { + t.Fatalf("expected packet to be read from link endpoint") + } else { + test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader())) + } + + ep.Close() + if state := ep.State(); state != transport.DatagramEndpointStateClosed { + t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed) + } + }) + } +} diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go new file mode 100644 index 000000000..4c2ae87f4 --- /dev/null +++ b/pkg/tcpip/transport/transport.go @@ -0,0 +1,16 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package transport supports transport protocols. +package transport diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index cdc344ab7..5cc7a2886 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -35,6 +35,8 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 108580508..ac7ecb5f8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,8 @@ package udp import ( + "fmt" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -25,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -40,36 +42,6 @@ type udpPacket struct { tos uint8 } -// EndpointState represents the state of a UDP endpoint. -type EndpointState tcpip.EndpointState - -// Endpoint states. Note that are represented in a netstack-specific manner and -// may not be meaningful externally. Specifically, they need to be translated to -// Linux's representation for these states if presented to userspace. -const ( - _ EndpointState = iota - StateInitial - StateBound - StateConnected - StateClosed -) - -// String implements fmt.Stringer. -func (s EndpointState) String() string { - switch s { - case StateInitial: - return "INITIAL" - case StateBound: - return "BOUND" - case StateConnected: - return "CONNECTING" - case StateClosed: - return "CLOSED" - default: - return "UNKNOWN" - } -} - // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -79,7 +51,6 @@ func (s EndpointState) String() string { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not @@ -87,6 +58,10 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -96,37 +71,19 @@ type endpoint struct { rcvBufSize int rcvClosed bool - // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - // state must be read/set using the EndpointState()/setEndpointState() - // methods. - state uint32 - route *stack.Route `state:"manual"` - dstPort uint16 - ttl uint8 - multicastTTL uint8 - multicastAddr tcpip.Address - multicastNICID tcpip.NICID - portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` lastError tcpip.Error + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags - // sendTOS represents IPv4 TOS or IPv6 TrafficClass, - // applied while sending packets. Defaults to 0 as on Linux. - sendTOS uint8 - - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - - // multicastMemberships that need to be remvoed when the endpoint is - // closed. Protected by the mu mutex. - multicastMemberships map[multicastMembership]struct{} + readShutdown bool // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -136,55 +93,25 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool -} -// +stateify savable -type multicastMembership struct { - nicID tcpip.NICID - multicastAddr tcpip.Address + localPort uint16 + remotePort uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.UDPProtocolNumber, - }, + stack: s, waiterQueue: waiterQueue, - // RFC 1075 section 5.4 recommends a TTL of 1 for membership - // requests. - // - // RFC 5135 4.2.1 appears to assume that IGMP messages have a - // TTL of 1. - // - // RFC 5135 Appendix A defines TTL=1: A multicast source that - // wants its traffic to not traverse a router (e.g., leave a - // home network) may find it useful to send traffic with IP - // TTL=1. - // - // Linux defaults to TTL=1. - multicastTTL: 1, - multicastMemberships: make(map[multicastMembership]struct{}), - state: uint32(StateInitial), - uniqueID: s.UniqueID(), + uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } -// setEndpointState updates the state of the endpoint to state atomically. This -// method is unexported as the only place we should update the state is in this -// package but we allow the state to be read freely without holding e.mu. -// -// Precondition: e.mu must be held to call this method. -func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32(&e.state, uint32(state)) -} - -// EndpointState() returns the current state of the endpoint. -func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32(&e.state)) -} - // UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -244,16 +157,22 @@ func (e *endpoint) Abort() { // associated with it. func (e *endpoint) Close() { e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.EndpointState() { - case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + e.mu.Unlock() + return + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice) portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: id.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -261,13 +180,10 @@ func (e *endpoint) Close() { e.stack.ReleasePort(portRes) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - for mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) - } - e.multicastMemberships = nil - // Close the receive list and drain it. e.rcvMu.Lock() e.rcvClosed = true @@ -278,14 +194,9 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.setEndpointState(StateClosed) - + e.net.Shutdown() + e.net.Close() + e.readShutdown = true e.mu.Unlock() e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) @@ -359,19 +270,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult return res, nil } -// prepareForWrite prepares the endpoint for sending data. In particular, it -// binds it if it's still in the initial state. To do so, it must first +// prepareForWriteInner prepares the endpoint for sending data. In particular, +// it binds it if it's still in the initial state. To do so, it must first // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.EndpointState() { - case StateInitial: - case StateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - case StateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -386,7 +297,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -398,33 +309,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip return true, nil } -// connectRoute establishes a route to the specified interface or the -// configured multicast interface if no interface is specified and the -// specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.ID.LocalAddress - if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { - // A packet can only originate from a unicast address (i.e., an interface). - localAddr = "" - } - - if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicID == 0 { - nicID = e.multicastNICID - } - if localAddr == "" && nicID == 0 { - localAddr = e.multicastAddr - } - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) - if err != nil { - return nil, 0, err - } - return r, nicID, nil -} - // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { @@ -448,18 +332,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { +func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return udpPacketInfo{}, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(opts.To) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { return udpPacketInfo{}, err } @@ -469,49 +348,27 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions } } - route := e.route - dstPort := e.dstPort + dst, connected := e.net.GetRemoteAddress() + dst.Port = e.remotePort if opts.To != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := opts.To.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return udpPacketInfo{}, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - if opts.To.Port == 0 { // Port 0 is an invalid port to send to. return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*opts.To) - if err != nil { - return udpPacketInfo{}, err - } - - r, _, err := e.connectRoute(nicID, dst, netProto) - if err != nil { - return udpPacketInfo{}, err - } - defer r.Release() - - route = r - dstPort = dst.Port + dst = *opts.To + } else if !connected { + return udpPacketInfo{}, &tcpip.ErrDestinationRequired{} } - if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} + ctx, err := e.net.AcquireContextForWrite(opts) + if err != nil { + return udpPacketInfo{}, err } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { + ctx.Release() return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { @@ -520,50 +377,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions if so.GetRecvError() { so.QueueLocalErr( &tcpip.ErrMessageTooLong{}, - route.NetProto(), + e.net.NetProto(), header.UDPMaximumPacketSize, - tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress(), - Port: dstPort, - }, + dst, v, ) } + ctx.Release() return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } - ttl := e.ttl - useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { - ttl = e.multicastTTL - // Multicast allows a 0 TTL. - useDefaultTTL = false - } - return udpPacketInfo{ - route: route, - data: buffer.View(v), - localPort: e.ID.LocalPort, - remotePort: dstPort, - ttl: ttl, - useDefaultTTL: useDefaultTTL, - tos: e.sendTOS, - owner: e.owner, - noChecksum: e.SocketOptions().GetNoChecksum(), + ctx: ctx, + data: v, + localPort: e.localPort, + remotePort: dst.Port, }, nil } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a // deadlock where the ICMP response handling ends up acquiring this endpoint's @@ -574,15 +406,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - u, err := e.buildUDPPacketInfo(p, opts) - if err != nil { + + if err := e.LastError(); err != nil { return 0, err } - n, err := u.send() + + udpInfo, err := e.prepareForWrite(p, opts) if err != nil { return 0, err } - return int64(n), nil + defer udpInfo.ctx.Release() + + pktInfo := udpInfo.ctx.PacketInfo() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength), + Data: udpInfo.data.ToVectorisedView(), + }) + + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber + + length := uint16(pkt.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: udpInfo.localPort, + DstPort: udpInfo.remotePort, + Length: length, + }) + + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if pktInfo.RequiresTXTransportChecksum && + (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) { + udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine( + header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length), + pkt.Data().AsRange().Checksum(), + ))) + } + if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + e.stack.Stats().UDP.PacketSendErrors.Increment() + return 0, err + } + + // Track count of packets sent. + e.stack.Stats().UDP.PacketsSent.Increment() + return int64(len(udpInfo.data)), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -601,36 +471,7 @@ func (e *endpoint) OnReusePortSet(v bool) { // SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.MTUDiscoverOption: - // Return not supported if the value is not disabling path - // MTU discovery. - if v != tcpip.PMTUDiscoveryDont { - return &tcpip.ErrNotSupported{} - } - - case tcpip.MulticastTTLOption: - e.mu.Lock() - e.multicastTTL = uint8(v) - e.mu.Unlock() - - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - } - - return nil + return e.net.SetSockOptInt(opt, v) } var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) @@ -642,145 +483,12 @@ func (e *endpoint) HasNIC(id int32) bool { // SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { - switch v := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - defer e.mu.Unlock() - - fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedLocked(fa) - if err != nil { - return err - } - nic := v.NIC - addr := fa.Addr - - if nic == 0 && addr == "" { - e.multicastAddr = "" - e.multicastNICID = 0 - break - } - - if nic != 0 { - if !e.stack.CheckNIC(nic) { - return &tcpip.ErrBadLocalAddress{} - } - } else { - nic = e.stack.CheckLocalAddress(0, netProto, addr) - if nic == 0 { - return &tcpip.ErrBadLocalAddress{} - } - } - - if e.BindNICID != 0 && e.BindNICID != nic { - return &tcpip.ErrInvalidEndpointState{} - } - - e.multicastNICID = nic - e.multicastAddr = addr - - case *tcpip.AddMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToInsert]; ok { - return &tcpip.ErrPortInUse{} - } - - if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - e.multicastMemberships[memToInsert] = struct{}{} - - case *tcpip.RemoveMembershipOption: - if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return &tcpip.ErrInvalidOptionValue{} - } - - nicID := v.NIC - if v.InterfaceAddr.Unspecified() { - if nicID == 0 { - if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { - nicID = r.NICID() - r.Release() - } - } - } else { - nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) - } - if nicID == 0 { - return &tcpip.ErrUnknownDevice{} - } - - memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - - e.mu.Lock() - defer e.mu.Unlock() - - if _, ok := e.multicastMemberships[memToRemove]; !ok { - return &tcpip.ErrBadLocalAddress{} - } - - if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { - return err - } - - delete(e.multicastMemberships, memToRemove) - - case *tcpip.SocketDetachFilterOption: - return nil - } - return nil + return e.net.SetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { - case tcpip.IPv4TOSOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.IPv6TrafficClassOption: - e.mu.RLock() - v := int(e.sendTOS) - e.mu.RUnlock() - return v, nil - - case tcpip.MTUDiscoverOption: - // The only supported setting is path MTU discovery disabled. - return tcpip.PMTUDiscoveryDont, nil - - case tcpip.MulticastTTLOption: - e.mu.Lock() - v := int(e.multicastTTL) - e.mu.Unlock() - return v, nil - case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -791,108 +499,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.mu.Lock() - v := int(e.ttl) - e.mu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } // GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { - switch o := opt.(type) { - case *tcpip.MulticastInterfaceOption: - e.mu.Lock() - *o = tcpip.MulticastInterfaceOption{ - NIC: e.multicastNICID, - InterfaceAddr: e.multicastAddr, - } - e.mu.Unlock() - - default: - return &tcpip.ErrUnknownProtocolOption{} - } - return nil + return e.net.GetSockOpt(opt) } -// udpPacketInfo contains all information required to send a UDP packet. -// -// This should be used as a value-only type, which exists in order to simplify -// return value syntax. It should not be exported or extended. +// udpPacketInfo holds information needed to send a UDP packet. type udpPacketInfo struct { - route *stack.Route - data buffer.View - localPort uint16 - remotePort uint16 - ttl uint8 - useDefaultTTL bool - tos uint8 - owner tcpip.PacketOwner - noChecksum bool -} - -// send sends the given packet. -func (u *udpPacketInfo) send() (int, tcpip.Error) { - vv := u.data.ToVectorisedView() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), - Data: vv, - }) - pkt.Owner = u.owner - - // Initialize the UDP header. - udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - pkt.TransportProtocolNumber = ProtocolNumber - - length := uint16(pkt.Size()) - udp.Encode(&header.UDPFields{ - SrcPort: u.localPort, - DstPort: u.remotePort, - Length: length, - }) - - // Set the checksum field unless TX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value indicates the - // transmitter skipped the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if u.route.RequiresTXTransportChecksum() && - (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { - xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range vv.Views() { - xsum = header.Checksum(v, xsum) - } - udp.SetChecksum(^udp.CalculateChecksum(xsum)) - } - - if u.useDefaultTTL { - u.ttl = u.route.DefaultTTL() - } - if err := u.route.WritePacket(stack.NetworkHeaderParams{ - Protocol: ProtocolNumber, - TTL: u.ttl, - TOS: u.tos, - }, pkt); err != nil { - u.route.Stats().UDP.PacketSendErrors.Increment() - return 0, err - } - - // Track count of packets sent. - u.route.Stats().UDP.PacketsSent.Increment() - return len(u.data), nil -} - -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil + ctx network.WriteContext + data buffer.View + localPort uint16 + remotePort uint16 } // Disconnect implements tcpip.Endpoint. @@ -900,7 +522,7 @@ func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.EndpointState() != StateConnected { + if e.net.State() != transport.DatagramEndpointStateConnected { return nil } var ( @@ -913,26 +535,28 @@ func (e *endpoint) Disconnect() tcpip.Error { boundPortFlags := e.boundPortFlags // Exclude ephemerally bound endpoints. - if e.BindNICID != 0 || e.ID.LocalAddress == "" { + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + if info.BindNICID != 0 || info.ID.LocalAddress == "" { var err tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.ID.LocalPort, - LocalAddress: e.ID.LocalAddress, + LocalPort: info.ID.LocalPort, + LocalAddress: info.ID.LocalAddress, } id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } - e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { - if e.ID.LocalPort != 0 { + if info.ID.LocalPort != 0 { // Release the ephemeral port. portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: info.ID.LocalAddress, + Port: info.ID.LocalPort, Flags: boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: tcpip.FullAddress{}, @@ -940,15 +564,14 @@ func (e *endpoint) Disconnect() tcpip.Error { e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } - e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) - e.ID = id + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = btd - e.route.Release() - e.route = nil - e.dstPort = 0 + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + + e.net.Disconnect() return nil } @@ -958,88 +581,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - var localPort uint16 - switch e.EndpointState() { - case StateInitial: - case StateBound, StateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } - - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.localPort + nextID.RemotePort = addr.Port + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - r, nicID, err := e.connectRoute(nicID, addr, netProto) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: e.ID.LocalAddress, - LocalPort: localPort, - RemotePort: addr.Port, - RemoteAddress: r.RemoteAddress(), - } + oldPortFlags := e.boundPortFlags - if e.EndpointState() == StateInitial { - id.LocalAddress = r.LocalAddress() - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv4ProtocolNumber, - header.IPv6ProtocolNumber, + nextID, btd, err := e.registerWithStack(netProtos, nextID) + if err != nil { + return err } - } - oldPortFlags := e.boundPortFlags + // Remove the old registration. + if e.localPort != 0 { + previousID.LocalPort = e.localPort + previousID.RemotePort = e.remotePort + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice) + } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = nextID.LocalPort + e.remotePort = nextID.RemotePort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { - r.Release() return err } - // Remove the old registration. - if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) - } - - e.ID = id - e.boundBindToDevice = btd - if e.route != nil { - // If the endpoint was already connected then make sure we release the - // previous route. - e.route.Release() - } - e.route = r - e.dstPort = addr.Port - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - e.setEndpointState(StateConnected) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1054,15 +637,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // A socket in the bound state can still receive multicast messages, - // so we need to notify waiters on shutdown. - if state := e.EndpointState(); state != StateBound && state != StateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } - e.shutdownFlags |= flags + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } + } if flags&tcpip.ShutdownRead != 0 { + e.readShutdown = true + e.rcvMu.Lock() wasClosed := e.rcvClosed e.rcvClosed = true @@ -1088,7 +679,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if e.ID.LocalPort == 0 { + if e.localPort == 0 { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, @@ -1126,56 +717,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.EndpointState() != StateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{boundNetProto} + if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } } - } - nicID := addr.NIC - if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { - // A local unicast address was specified, verify that it's valid. - nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicID == 0 { - return &tcpip.ErrBadLocalAddress{} + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: boundAddr, + } + id, btd, err := e.registerWithStack(netProtos, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, btd, err := e.registerWithStack(netProtos, id) + e.localPort = id.LocalPort + e.boundBindToDevice = btd + e.effectiveNetProtos = netProtos + return nil + }) if err != nil { return err } - e.ID = id - e.boundBindToDevice = btd - e.RegisterNICID = nicID - e.effectiveNetProtos = netProtos - - // Mark endpoint as bound. - e.setEndpointState(StateBound) - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() - return nil } @@ -1190,9 +768,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { return err } - // Save the effective NICID generated by bindLocked. - e.BindNICID = e.RegisterNICID - return nil } @@ -1201,16 +776,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - addr := e.ID.LocalAddress - if e.EndpointState() == StateConnected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.localPort + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -1218,15 +786,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.EndpointState() != StateConnected || e.dstPort == 0 { + addr, connected := e.net.GetRemoteAddress() + if !connected || e.remotePort == 0 { return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + addr.Port = e.remotePort + return addr, nil } // Readiness returns the current readiness of the endpoint. For example, if @@ -1376,19 +942,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p payload = udp.Payload() } + id := e.net.Info().ID e.SocketOptions().QueueErr(&tcpip.SockError{ Err: err, Cause: transErr, Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: id.RemoteAddress, + Port: e.remotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: id.LocalAddress, + Port: e.localPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -1403,7 +970,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // TODO(gvisor.dev/issues/5270): Handle all transport errors. switch transErr.Kind() { case stack.DestinationPortUnreachableTransportError: - if e.EndpointState() == StateConnected { + if e.net.State() == transport.DatagramEndpointStateConnected { e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } @@ -1411,16 +978,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { - return uint32(e.EndpointState()) + return uint32(e.net.State()) } // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() - return &ret + defer e.mu.RUnlock() + info := e.net.Info() + info.ID.LocalPort = e.localPort + info.ID.RemotePort = e.remotePort + return &info } // Stats returns a pointer to the endpoint stats. @@ -1431,13 +999,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} -func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) -} - // SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // SocketOptions implements tcpip.Endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 1f638c3f6..20c45ab87 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,12 +15,13 @@ package udp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -66,50 +67,28 @@ func (e *endpoint) Resume(s *stack.Stack) { e.mu.Lock() defer e.mu.Unlock() + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - for m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - state := e.EndpointState() - if state != StateBound && state != StateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err tcpip.Error - if state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + var err tcpip.Error + id := e.net.Info().ID + id.LocalPort = e.localPort + id.RemotePort = e.remotePort + id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound - // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.ID - e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) - if err != nil { - panic(err) + e.localPort = id.LocalPort + e.remotePort = id.RemotePort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 7c357cb09..7238fc019 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + ep.mu.Lock() + defer ep.mu.Unlock() + netHdr := r.pkt.Network() - route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { + if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil { + return nil, err + } + + if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil { return nil, err } - ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() - route.Release() return nil, err } - ep.ID = r.id - ep.route = route - ep.dstPort = r.id.RemotePort + ep.localPort = r.id.LocalPort + ep.remotePort = r.id.RemotePort ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} - ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = uint32(StateConnected) - ep.rcvMu.Lock() ep.rcvReady = true ep.rcvMu.Unlock() |