diff options
author | Kevin Krakauer <krakauer@google.com> | 2019-04-02 11:12:29 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-02 11:13:49 -0700 |
commit | 52a51a8e20b3e5c28eb1e66bd57203216cf644c5 (patch) | |
tree | 4d98f67dc4abb15e6de568b2af12a3e12f41aa99 /pkg/tcpip/transport | |
parent | 1df3fa69977477092efa65a8de407bd6f0f88db4 (diff) |
Add a raw socket transport endpoint and use it for raw ICMP sockets.
Having raw socket code together will make it easier to add support for other raw
network protocols. Currently, only ICMP uses the raw endpoint. However, adding
support for other protocols such as UDP shouldn't be much more difficult than
adding a few switch cases.
PiperOrigin-RevId: 241564875
Change-Id: I77e03adafe4ce0fd29ba2d5dfdc547d2ae8f25bf
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/BUILD | 45 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/raw.go | 558 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/state.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 |
8 files changed, 719 insertions, 58 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 74d9ff253..9aa6f3978 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/raw", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 182097b46..8f2e3aa20 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -59,10 +59,6 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue - // raw indicates whether the endpoint is intended for use by a raw - // socket, which returns the network layer header along with the - // payload. It is immutable. - raw bool // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -80,32 +76,26 @@ type endpoint struct { shutdownFlags tcpip.ShutdownFlags id stack.TransportEndpointID state endpointState - bindNICID tcpip.NICID - bindAddr tcpip.Address - regNICID tcpip.NICID - route stack.Route `state:"manual"` + // bindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + bindNICID tcpip.NICID + bindAddr tcpip.Address + // regNICID is the default NIC to be used when callers don't specify a + // NIC. + regNICID tcpip.NICID + route stack.Route `state:"manual"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, raw bool) (*endpoint, *tcpip.Error) { - e := &endpoint{ +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return &endpoint{ stack: stack, netProto: netProto, transProto: transProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, - raw: raw, - } - - // Raw endpoints must be immediately bound because they receive all - // ICMP traffic starting from when they're created via socket(). - if raw { - if err := e.bindLocked(tcpip.FullAddress{}); err != nil { - return nil, err - } - } - - return e, nil + }, nil } // Close puts the endpoint in a closed state and frees all resources @@ -115,11 +105,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - if e.raw { - e.stack.UnregisterRawTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e) - } else { - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) - } + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) } // Close the receive list and drain it. @@ -244,8 +230,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c route = &e.route if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, given that it may - // need to change in Route.Resolve() call below. + // Promote lock to exclusive if using a shared route, + // given that it may need to change in Route.Resolve() + // call below. e.mu.RUnlock() defer e.mu.RLock() @@ -290,8 +277,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c waker := &sleep.Waker{} if ch, err := route.Resolve(waker); err != nil { if err == tcpip.ErrWouldBlock { - // Link address needs to be resolved. Resolution was triggered the - // background. Better luck next time. + // Link address needs to be resolved. + // Resolution was triggered the background. + // Better luck next time. route.RemoveWaker(waker) return 0, ch, tcpip.ErrNoLinkAddress } @@ -368,11 +356,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { - if e.raw { - hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) - } - if len(data) < header.ICMPv4EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -439,11 +422,6 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - // TODO: We don't yet support connect on a raw socket. - if e.raw { - return tcpip.ErrNotSupported - } - e.mu.Lock() defer e.mu.Unlock() @@ -547,11 +525,6 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.raw { - err := e.stack.RegisterRawTransportEndpoint(nicid, netProtos, e.transProto, e, false) - return stack.TransportEndpointID{}, err - } - if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -687,11 +660,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().DroppedPackets.Increment() e.rcvMu.Unlock() return } @@ -706,13 +680,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, ne }, } - if e.raw { - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) - } else { - pkt.data = vv.Clone(pkt.views[:]) - } + pkt.data = vv.Clone(pkt.views[:]) e.rcvList.PushBack(pkt) e.rcvBufSize += pkt.data.Size() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 36b70988a..09ee2f892 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -30,6 +30,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -73,7 +74,7 @@ func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtoco if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue, false) + return newEndpoint(stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements @@ -82,7 +83,7 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue, true) + return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD new file mode 100644 index 000000000..005079639 --- /dev/null +++ b/pkg/tcpip/transport/raw/BUILD @@ -0,0 +1,45 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "raw", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "raw", + srcs = [ + "packet_list.go", + "raw.go", + "state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw", + imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +filegroup( + name = "autogen", + srcs = [ + "packet_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/transport/raw/raw.go b/pkg/tcpip/transport/raw/raw.go new file mode 100644 index 000000000..8dada2e4f --- /dev/null +++ b/pkg/tcpip/transport/raw/raw.go @@ -0,0 +1,558 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package raw provides the implementation of raw sockets (see raw(7)). Raw +// sockets allow applications to: +// +// * manually write and inspect transport layer headers and payloads +// * receive all traffic of a given transport protcol (e.g. ICMP or UDP) +// * optionally write and inspect network layer and link layer headers for +// packets +// +// Raw sockets don't have any notion of ports, and incoming packets are +// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will +// receive every UDP packet received by netstack. bind(2) and connect(2) can be +// used to filter incoming packets by source and destination. +package raw + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // views is pre-allocated space to back data. As long as the packet is + // made up of fewer than 8 buffer.Views, no extra allocation is + // necessary to store packet data. + views [8]buffer.View `state:"nosave"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to +// have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + closed bool + connected bool + bound bool + // registeredNIC is the NIC to which th endpoint is explicitly + // registered. Is set when Connect or Bind are used to specify a NIC. + registeredNIC tcpip.NICID + // boundNIC and boundAddr are set on calls to Bind(). When callers + // attempt actions that would invalidate the binding data (e.g. sending + // data via a NIC other than boundNIC), the endpoint will return an + // error. + boundNIC tcpip.NICID + boundAddr tcpip.Address + // route is the route to a remote network endpoint. It is set via + // Connect(), and is valid only when conneted is true. + route stack.Route `state:"manual"` +} + +// NewEndpoint returns a raw endpoint for the given protocols. +// TODO: IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != header.IPv4ProtocolNumber { + return nil, tcpip.ErrUnknownProtocol + } + + ep := &endpoint{ + stack: stack, + netProto: netProto, + transProto: transProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + return nil, err + } + + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + if ep.connected { + ep.route.Release() + } + + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +// Write implements tcpip.Endpoint.Write. +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + ep.mu.RLock() + + if ep.closed { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Check whether we've shutdown writing. + if ep.shutdownFlags&tcpip.ShutdownWrite != 0 { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrClosedForSend + } + + // Did the user caller provide a destination? If not, use the connected + // destination. + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if !ep.connected { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrNotConnected + } + + if ep.route.IsResolutionRequired() { + savedRoute := &ep.route + // Promote lock to exclusive if using a shared route, + // given that it may need to change in finishWrite. + ep.mu.RUnlock() + ep.mu.Lock() + + // Make sure that the route didn't change during the + // time we didn't hold the lock. + if !ep.connected || savedRoute != &ep.route { + ep.mu.Unlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + n, ch, err := ep.finishWrite(payload, savedRoute) + ep.mu.Unlock() + return n, ch, err + } + + n, ch, err := ep.finishWrite(payload, &ep.route) + ep.mu.RUnlock() + return n, ch, err + } + + // The caller provided a destination. Reject destination address if it + // goes through a different NIC than the endpoint was bound to. + nic := opts.To.NIC + if ep.bound && nic != 0 && nic != ep.boundNIC { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrNoRoute + } + + // We don't support IPv6 yet, so this has to be an IPv4 address. + if len(opts.To.Addr) != header.IPv4AddressSize { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Find the route to the destination. If boundAddress is 0, + // FindRoute will choose an appropriate source address. + route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + n, ch, err := ep.finishWrite(payload, &route) + route.Release() + ep.mu.RUnlock() + return n, ch, err +} + +// finishWrite writes the payload to a route. It resolves the route if +// necessary. It's really just a helper to make defer unnecessary in Write. +func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { + // We may need to resolve the route (match a link layer address to the + // network address). If that requires blocking (e.g. to use ARP), + // return a channel on which the caller can wait. + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + if ch, err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. + // Resolution was triggered the background. + // Better luck next time. + route.RemoveWaker(waker) + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + return 0, nil, err + } + + switch ep.netProto { + case header.IPv4ProtocolNumber: + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), header.ICMPv4ProtocolNumber, route.DefaultTTL()); err != nil { + return 0, nil, err + } + + default: + return 0, nil, tcpip.ErrUnknownProtocol + } + + return uintptr(len(payloadBytes)), nil, nil +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Connect implements tcpip.Endpoint.Connect. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return tcpip.ErrInvalidEndpointState + } + + // We don't support IPv6 yet. + if len(addr.Addr) != header.IPv4AddressSize { + return tcpip.ErrInvalidEndpointState + } + + nic := addr.NIC + if ep.bound { + if ep.boundNIC == 0 { + // If we're bound, but not to a specific NIC, the NIC + // in addr will be used. Nothing to do here. + } else if addr.NIC == 0 { + // If we're bound to a specific NIC, but addr doesn't + // specify a NIC, use the bound NIC. + nic = ep.boundNIC + } else if addr.NIC != ep.boundNIC { + // We're bound and addr specifies a NIC. They must be + // the same. + return tcpip.ErrInvalidEndpointState + } + } + + // Find a route to the destination. + route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + if err != nil { + return err + } + defer route.Release() + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + // Save the route and NIC we've connected via. + ep.route = route.Clone() + ep.registeredNIC = nic + ep.connected = true + + return nil +} + +// Shutdown implements tcpip.Endpoint.Shutdown. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if !ep.connected { + return tcpip.ErrNotConnected + } + + ep.shutdownFlags |= flags + + if flags&tcpip.ShutdownRead != 0 { + ep.rcvMu.Lock() + wasClosed := ep.rcvClosed + ep.rcvClosed = true + ep.rcvMu.Unlock() + + if !wasClosed { + ep.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen implements tcpip.Endpoint.Listen. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + // Callers must provide an IPv4 address or no network address (for + // binding to a NIC, but not an address). + if len(addr.Addr) != 0 && len(addr.Addr) != 4 { + return tcpip.ErrInvalidEndpointState + } + + // If a local address was specified, verify that it's valid. + if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC + ep.boundAddr = addr.Addr + ep.bound = true + + return nil +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if !ep.connected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: ep.registeredNIC, + Addr: ep.route.RemoteAddress, + }, nil +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + ep.mu.Lock() + *o = tcpip.SendBufferSizeOption(ep.sndBufSize) + ep.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) + ep.rcvMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + ep.rcvMu.Lock() + if ep.rcvList.Empty() { + *o = 0 + } else { + p := ep.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + ep.rcvMu.Unlock() + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// HandlePacket implements stack.RawTransportEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.stack.Stats().DroppedPackets.Increment() + ep.rcvMu.Unlock() + return + } + + if ep.bound { + // If bound to a NIC, only accept data for that NIC. + if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { + ep.rcvMu.Unlock() + return + } + // If bound to an address, only accept data for that address. + if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + } + + // If connected, only accept packets from the remote address we + // connected to. + if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + packet := &packet{ + senderAddr: tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + }, + } + + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + packet.data = combinedVV.Clone(packet.views[:]) + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} diff --git a/pkg/tcpip/transport/raw/state.go b/pkg/tcpip/transport/raw/state.go new file mode 100644 index 000000000..e3891a8b8 --- /dev/null +++ b/pkg/tcpip/transport/raw/state.go @@ -0,0 +1,88 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // If the endpoint is connected, re-connect via the save/restore stack. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind via the save/restore stack. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0427af34f..41c87cc7e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1438,7 +1438,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5637f46e3..19e532180 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -940,7 +940,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { |