diff options
Diffstat (limited to 'pkg')
26 files changed, 534 insertions, 517 deletions
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 796efa240..59784eacb 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -509,6 +509,24 @@ func TestView(t *testing.T) { } } +func TestViewClone(t *testing.T) { + const ( + originalSize = 90 + bytesToDelete = 30 + ) + var v View + v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize)) + + clonedV := v.Clone() + v.TrimFront(bytesToDelete) + if got, want := int(v.Size()), originalSize-bytesToDelete; got != want { + t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) + } + if got := clonedV.Size(); got != originalSize { + t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) + } +} + func TestViewPullUp(t *testing.T) { for _, tc := range []struct { desc string diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD index 0a045fc8e..2a1602e2b 100644 --- a/pkg/safecopy/BUILD +++ b/pkg/safecopy/BUILD @@ -18,9 +18,9 @@ go_library( ], visibility = ["//:sandbox"], deps = [ - "//pkg/abi/linux", "//pkg/errors", "//pkg/errors/linuxerr", + "//pkg/sighandling", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index a9711e63d..0dd0aea83 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/sighandling" ) // SegvError is returned when a safecopy function receives SIGSEGV. @@ -132,10 +133,10 @@ func initializeAddresses() { func init() { initializeAddresses() - if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) } - if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) } linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) { diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index 2365b2c0d..15f84abea 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -20,7 +20,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" ) // maxRegisterSize is the maximum register size used in memcpy and memclr. It @@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) } } - -// ReplaceSignalHandler replaces the existing signal handler for the provided -// signal with the one that handles faults in safecopy-protected functions. -// -// It stores the value of the previously set handler in previous. -// -// This function will be called on initialization in order to install safecopy -// handlers for appropriate signals. These handlers will call the previous -// handler however, and if this is function is being used externally then the -// same courtesy is expected. -func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { - var sa linux.SigAction - const maskLen = 8 - - // Get the existing signal handler information, and save the current - // handler. Once we replace it, we will use this pointer to fall back to - // it when we receive other signals. - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { - return e - } - - // Fail if there isn't a previous handler. - if sa.Handler == 0 { - return fmt.Errorf("previous handler for signal %x isn't set", sig) - } - - *previous = uintptr(sa.Handler) - - // Install our own handler. - sa.Handler = uint64(handler) - if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { - return e - } - - return nil -} diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index c5b099559..f0c168ecc 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -191,9 +191,11 @@ const ( // // Preconditions: The task's owning TaskSet.mu must be locked. func (t *Task) updateInfoLocked() { - // Use the task's TID in the root PID namespace for logging. + // Use the task's TID and PID in the root PID namespace for logging. + pid := t.tg.pidns.owner.Root.tgids[t.tg] tid := t.tg.pidns.owner.Root.tids[t] - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid)) + t.rebuildTraceContext(tid) } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index a26f54269..834d72408 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -63,7 +63,6 @@ go_library( "//pkg/procid", "//pkg/ring0", "//pkg/ring0/pagetables", - "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/arch/fpu", @@ -71,6 +70,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/time", + "//pkg/sighandling", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 826997e77..5be2215ed 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,8 +19,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/ring0" - "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sighandling" ) // bluepill enters guest mode. @@ -97,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index dcf34015d..f1f7e4ea4 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -28,9 +28,9 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/seccomp" ktime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sighandling" "gvisor.dev/gvisor/pkg/sync" ) @@ -723,7 +723,7 @@ func addrOfSigsysHandler() uintptr func seccompMmapRules(m *machine) { seccompMmapRulesOnce.Do(func() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { + if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } rules := []seccomp.RuleSet{} diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 757ff2a40..4d3f4d556 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output [] if err == nil { // Fill in the output after successful execution. i.post(t, args, retval, output, LogMaximumSize) - rval = fmt.Sprintf("%#x (%v)", retval, elapsed) + rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed) } else { - rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed) + rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed) } switch len(output) { diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD index 1790d57c9..72f10f982 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sighandling/BUILD @@ -8,7 +8,7 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - visibility = ["//pkg/sentry:internal"], + visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go index bdaf8af29..bdaf8af29 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sighandling/sighandling.go diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go index 3fe5c6770..7deeda042 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sighandling/sighandling_unsafe.go @@ -15,6 +15,7 @@ package sighandling import ( + "fmt" "unsafe" "golang.org/x/sys/unix" @@ -37,3 +38,36 @@ func IgnoreChildStop() error { return nil } + +// ReplaceSignalHandler replaces the existing signal handler for the provided +// signal with the function pointer at `handler`. This bypasses the Go runtime +// signal handlers, and should only be used for low-level signal handlers where +// use of signal.Notify is not appropriate. +// +// It stores the value of the previously set handler in previous. +func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { + var sa linux.SigAction + const maskLen = 8 + + // Get the existing signal handler information, and save the current + // handler. Once we replace it, we will use this pointer to fall back to + // it when we receive other signals. + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { + return e + } + + // Fail if there isn't a previous handler. + if sa.Handler == 0 { + return fmt.Errorf("previous handler for signal %x isn't set", sig) + } + + *previous = uintptr(sa.Handler) + + // Install our own handler. + sa.Handler = uint64(handler) + if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d51c36f19..1c3b0887f 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -167,14 +167,17 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet p := hdr.TransportProtocol() dstAddr := hdr.DestinationAddress() // Skip the ip header, then deliver the error. - pkt.Data().DeleteFront(hlen) + if _, ok := pkt.Data().Consume(hlen); !ok { + panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen)) + } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. + // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types + // require consuming the header, so we only call PullUp. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() @@ -242,7 +245,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will - // be modifying fields. + // be modifying fields. Both the ICMP header (with its type modified to + // EchoReply) and payload are reused in the reply packet. // // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to @@ -331,6 +335,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.echoReply.Increment() + // ICMP sockets expect the ICMP header to be present, so we don't consume + // the ICMP header. e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: @@ -338,7 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { mtu := h.MTU() code := h.Code() - pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok { + panic("could not consume ICMPv4MinimumSize bytes") + } switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index dda473e48..9b71738ae 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6c6107264..ff23d48e7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().DeleteFront(header.IPv6MinimumSize) + if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok { + panic("could not consume IPv6MinimumSize bytes") + } if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) + if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok { + panic("could not consume IPv6FragmentHeaderSize bytes") + } } e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) @@ -325,7 +329,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) + hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize) if !ok { received.invalid.Increment() return @@ -334,18 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r if err != nil { networkMTU = 0 } - pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) + hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize) if !ok { received.invalid.Increment() return } code := header.ICMPv6(hdr).Code() - pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e2d2cf907..600e805f8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -1537,19 +1537,22 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. - // Calculate the number of octets parsed from data. We want to remove all - // the data except the unparsed portion located at the end, which its size - // is extHdr.Buf.Size(). + // Calculate the number of octets parsed from data. We want to consume all + // the data except the unparsed portion located at the end, whose size is + // extHdr.Buf.Size(). trim := pkt.Data().Size() - extHdr.Buf.Size() // For unfragmented packets, extHdr still contains the transport header. - // Get rid of it. + // Consume that too. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. trim += pkt.TransportHeader().View().Size() - pkt.Data().DeleteFront(trim) + if _, ok := pkt.Data().Consume(trim); !ok { + stats.MalformedPacketsReceived.Increment() + return fmt.Errorf("could not consume %d bytes", trim) + } stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 4fb7e9adb..b7cb54b1d 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -388,28 +388,33 @@ func (ct *ConnTrack) insertConn(conn *conn) { // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false - } - switch hook { case Prerouting, Input, Output, Postrouting: default: return false } - transportHeader, ok := getTransportHeader(pkt) - if !ok { + if conn, dir := ct.connFor(pkt); conn != nil { + conn.handlePacket(pkt, hook, dir, r) return false } - conn, dir := ct.connFor(pkt) // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + // + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting +} + +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) { + if pkt.NatDone { + return + } + + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return } netHeader := pkt.Network() @@ -425,24 +430,24 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination && dir == dirOriginal { - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr + if cn.manip == manipDestination && dir == dirOriginal { + newPort = cn.reply.srcPort + newAddr = cn.reply.srcAddr pkt.NatDone = true - } else if conn.manip == manipSource && dir == dirReply { - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr + } else if cn.manip == manipSource && dir == dirReply { + newPort = cn.original.srcPort + newAddr = cn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource && dir == dirOriginal { - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr + if cn.manip == manipSource && dir == dirOriginal { + newPort = cn.reply.dstPort + newAddr = cn.reply.dstAddr updateSRCFields = true pkt.NatDone = true - } else if conn.manip == manipDestination && dir == dirReply { - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr + } else if cn.manip == manipDestination && dir == dirReply { + newPort = cn.original.dstPort + newAddr = cn.original.dstAddr updateSRCFields = true pkt.NatDone = true } @@ -451,7 +456,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } if !pkt.NatDone { - return false + return } fullChecksum := false @@ -486,15 +491,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { ) // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() + cn.mu.Lock() + defer cn.mu.Unlock() // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() + cn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(pkt, hook) - - return false + cn.updateLocked(pkt, hook) } // maybeInsertNoop tries to insert a no-op connection entry to keep connections diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 74c9075b4..dcba7eba6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -310,8 +310,8 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) // must be dropped if false is returned. // // Precondition: The packet's network and transport header must be set. -func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { - return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + return it.check(Postrouting, pkt, r, addressEP, "" /* inNicName */, outNicName) } // check runs pkt through the rules for hook. It returns true when the packet @@ -431,7 +431,9 @@ func (it *IPTables) startReaper(interval time.Duration) { // // Precondition: The packets' network and transport header must be set. func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Output, pkts, r, outNicName) + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }) } // CheckPostroutingPackets performs the postrouting hook on the packets. @@ -439,21 +441,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa // Returns a map of packets that must be dropped. // // Precondition: The packets' network and transport header must be set. -func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Postrouting, pkts, r, outNicName) +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }) } -// checkPackets runs pkts through the rules for hook and returns a map of -// packets that should not go forward. -// -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -// -// Precondition: The packets' network and transport header must be set. -func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { + if ok := f(pkt); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index e8806ebdb..949c44c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -162,7 +162,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) + conn.handlePacket(pkt, hook, dirOriginal, r) } default: return RuleDrop, 0 @@ -181,15 +181,7 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { - // Sanity check. - if st.NetworkProtocol != pkt.NetworkProtocolNumber { - panic(fmt.Sprintf( - "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", - st.NetworkProtocol, pkt.NetworkProtocolNumber)) - } - +func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -200,16 +192,8 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleDrop, 0 } - switch hook { - case Postrouting, Input: - case Prerouting, Output, Forward: - panic(fmt.Sprintf("%s not supported", hook)) - default: - panic(fmt.Sprintf("%s unrecognized", hook)) - } - - port := st.Port - + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. if port == 0 { switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: @@ -228,13 +212,69 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou // tracking. // // Does nothing if the protocol does not support connection tracking. - if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) + if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil { + conn.handlePacket(pkt, hook, dirOriginal, r) } return RuleAccept, 0 } +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + return snatAction(pkt, ct, hook, r, st.Port, st.Addr) +} + +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. + return RuleDrop, 0 + } + + address := ep.AddressWithPrefix().Address + ep.DecRef() + return snatAction(pkt, ct, hook, r, 0 /* port */, address) +} + func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { if updateSRCFields { if fullChecksum { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index bf248ef20..456b0cf80 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -425,13 +425,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 87b023445..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) { } } -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string @@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index cd4137794..c23e91702 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index bdf4a64b9..f01e2b128 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1197,24 +1197,15 @@ func TestSNAT(t *testing.T) { tests := []struct { name string + netProto tcpip.NetworkProtocolNumber epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses }{ { - name: "IPv4 host1 server with host2 client", + name: "IPv4 host1 server with host2 client", + netProto: ipv4.ProtocolNumber, epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { t.Helper() - ipt := routerStack.IPTables() - filter := ipt.GetTable(stack.NATID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} - filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err) - } - ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) return endpointAndAddresses{ @@ -1231,21 +1222,11 @@ func TestSNAT(t *testing.T) { }, }, { - name: "IPv6 host1 server with host2 client", + name: "IPv6 host1 server with host2 client", + netProto: ipv6.ProtocolNumber, epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { t.Helper() - ipt := routerStack.IPTables() - filter := ipt.GetTable(stack.NATID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Postrouting] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} - filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err) - } - ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) return endpointAndAddresses{ @@ -1324,120 +1305,165 @@ func TestSNAT(t *testing.T) { }, } + setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + filter := ipt.GetTable(stack.NATID, ipv6) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = target + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } + } + + natTypes := []struct { + name string + setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address) + }{ + { + name: "SNAT", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) { + t.Helper() + + setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr}) + }, + }, + { + name: "Masquerade", + setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) { + t.Helper() + + setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { for _, subTest := range subTests { t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - } + for _, natType := range natTypes { + t.Run(natType.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - if subTest.setupServer != nil { - subTest.setupServer(t, epsAndAddrs.serverEP) - } - { - err := epsAndAddrs.clientEP.Connect(serverAddr) - if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) - } - } - nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} - if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { - t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) - } else { - nattedClientAddr.Port = addr.Port - } + natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr) - serverEP := epsAndAddrs.serverEP - serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { - defer ep.Close() - serverEP = ep - serverCH = ch - } + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } - write := func(ep tcpip.Endpoint, data []byte) { - t.Helper() - - var r bytes.Reader - r.Reset(data) - var wOpts tcpip.WriteOptions - n, err := ep.Write(&r, wOpts) - if err != nil { - t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) - } - } + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } + } + nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + nattedClientAddr.Port = addr.Port + } - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { - t.Helper() - - var buf bytes.Buffer - var res tcpip.ReadResult - for { - var err tcpip.Error - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err = ep.Read(&buf, opts) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch } - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } } - break - } - - readResult := tcpip.ReadResult{ - Count: len(data), - Total: len(data), - } - if subTest.needRemoteAddr { - readResult.RemoteAddr = expectedFrom - } - if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() - } - } - { - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data) - read(serverCH, serverEP, data, nattedClientAddr) - } + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } - { - data := []byte{5, 6, 7, 8, 9, 10, 11, 12} - write(serverEP, data) - read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, nattedClientAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + } + }) } }) } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7115d0a12..95fcdc1b6 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -15,12 +15,12 @@ package tcp import ( + "container/list" "crypto/sha1" "encoding/binary" "fmt" "hash" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -100,18 +100,6 @@ type listenContext struct { // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. netProto tcpip.NetworkProtocolNumber - - // pendingMu protects pendingEndpoints. This should only be accessed - // by the listening endpoint's worker goroutine. - // - // Lock Ordering: listenEP.workerMu -> pendingMu - pendingMu sync.Mutex - // pending is used to wait for all pendingEndpoints to finish when - // a socket is closed. - pending sync.WaitGroup - // pendingEndpoints is a map of all endpoints for which a handshake is - // in progress. - pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 { // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stk, - protocol: protocol, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6Only: v6Only, - netProto: netProto, - listenEP: listenEP, - pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), + stack: stk, + protocol: protocol, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6Only: v6Only, + netProto: netProto, + listenEP: listenEP, } for i := range l.nonce { @@ -265,7 +252,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu return nil, &tcpip.ErrConnectionAborted{} } - l.addPendingEndpoint(ep) // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. @@ -275,8 +261,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) - return nil, &tcpip.ErrConnectionAborted{} } @@ -295,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - ep.drainClosingSegmentQueue() return nil, err @@ -336,38 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, return ep, nil } -func (l *listenContext) addPendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - l.pendingEndpoints[n.TransportEndpointInfo.ID] = n - l.pending.Add(1) - l.pendingMu.Unlock() -} - -func (l *listenContext) removePendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) - l.pending.Done() - l.pendingMu.Unlock() -} - -func (l *listenContext) closeAllPendingEndpoints() { - l.pendingMu.Lock() - for _, n := range l.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - l.pendingMu.Unlock() - l.pending.Wait() -} - // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() e.Close() e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.drainClosingSegmentQueue() e.h = nil } @@ -378,9 +332,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.isConnectNotified = true // Update the receive window scaling. We can't do it before the @@ -444,101 +395,30 @@ func (e *endpoint) notifyAborted() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// handleSynSegment is called in its own goroutine once the listening endpoint -// receives a SYN segment. It is responsible for completing the handshake and -// queueing the new endpoint for acceptance. -// -// A limited number of these goroutines are allowed before TCP starts using SYN -// cookies to accept connections. -// -// +checklocks:e.mu -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { - defer s.decRef() - - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) - if err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - atomic.AddInt32(&e.synRcvdCount, -1) - return err - } - - go func() { - // Note that startHandshake returns a locked endpoint. The - // force call here just makes it so. - if err := h.complete(); err != nil { // +checklocksforce - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - atomic.AddInt32(&e.synRcvdCount, -1) - return - } - ctx.cleanupCompletedHandshake(h) - h.ep.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - // Deliver the endpoint to the accept queue. - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - // If the listener has transitioned out of the listen state (accepted - // is the zero value), the new endpoint is reset instead. - return false - } - if e.accepted.acceptQueueIsFullLocked() { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(h.ep) - atomic.AddInt32(&e.synRcvdCount, -1) - return true - } - }() - - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - h.ep.notifyProtocolGoroutine(notifyReset) - } - }() - - return nil -} - -func (e *endpoint) synRcvdBacklogFull() bool { - e.acceptMu.Lock() - acceptedCap := e.accepted.cap - e.acceptMu.Unlock() - // The capacity of the accepted queue would always be one greater than the - // listen backlog. But, the SYNRCVD connections count is always checked - // against the listen backlog value for Linux parity reason. - // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - // - // We maintain an equality check here as the synRcvdCount is incremented - // and compared only from a single listener context and the capacity of - // the accepted queue can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := e.accepted.acceptQueueIsFullLocked() + full := e.acceptQueue.isFull() e.acceptMu.Unlock() return full } -func (a *accepted) acceptQueueIsFullLocked() bool { - return a.endpoints.Len() == a.cap +// +stateify savable +type acceptQueue struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` + + // pendingEndpoints is a set of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[*endpoint]struct{} + + // capacity is the maximum number of endpoints that can be in endpoints. + capacity int +} + +func (a *acceptQueue) isFull() bool { + return a.endpoints.Len() == a.capacity } // handleListenSegment is called when a listening endpoint receives a segment @@ -571,20 +451,96 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - alwaysUseSynCookies := func() bool { + opts := parseSynSegmentOptions(s) + + useSynCookies, err := func() (bool, tcpip.Error) { var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) } - return bool(alwaysUseSynCookies) - }() + if alwaysUseSynCookies { + return true, nil + } + e.acceptMu.Lock() + defer e.acceptMu.Unlock() - opts := parseSynSegmentOptions(s) - if !alwaysUseSynCookies && !e.synRcvdBacklogFull() { - s.incRef() - atomic.AddInt32(&e.synRcvdCount, 1) - return e.handleSynSegment(ctx, s, opts) + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 { + return true, nil + } + + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return false, err + } + + e.acceptQueue.pendingEndpoints[h.ep] = struct{}{} + e.pendingAccepted.Add(1) + + go func() { + defer func() { + e.pendingAccepted.Done() + + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + delete(e.acceptQueue.pendingEndpoints, h.ep) + }() + + // Note that startHandshake returns a locked endpoint. The force call + // here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + return + } + ctx.cleanupCompletedHandshake(h) + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + + // Deliver the endpoint to the accept queue. + // + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + // The listener is transitioning out of the Listen state; bail. + if e.acceptQueue.capacity == 0 { + return false + } + if e.acceptQueue.isFull() { + e.acceptCond.Wait() + continue + } + + e.acceptQueue.endpoints.PushBack(h.ep) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } + }() + + return false, nil + }() + if err != nil { + return err } + if !useSynCookies { + return nil + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -631,7 +587,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // if there is an error), to guarantee that we will keep our spot in the // queue even if another handshake from the syn queue completes. e.acceptMu.Lock() - if e.accepted.acceptQueueIsFullLocked() { + if e.acceptQueue.isFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be // retransmitted by the sender anyway and we can @@ -769,7 +725,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() // Deliver the endpoint to the accept queue. - e.accepted.endpoints.PushBack(n) + e.acceptQueue.endpoints.PushBack(n) e.acceptMu.Unlock() e.waiterQueue.Notify(waiter.ReadableEvents) @@ -789,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto) defer func() { - // Mark endpoint as closed. This will prevent goroutines running - // handleSynSegment() from attempting to queue new connections - // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. - ctx.closeAllPendingEndpoints() - // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 407ab2664..b60f9becf 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,7 +15,6 @@ package tcp import ( - "container/list" "encoding/binary" "fmt" "io" @@ -205,6 +204,8 @@ type SACKInfo struct { } // ReceiveErrors collect segment receive errors within transport layer. +// +// +stateify savable type ReceiveErrors struct { tcpip.ReceiveErrors @@ -234,6 +235,8 @@ type ReceiveErrors struct { } // SendErrors collect segment send errors within the transport layer. +// +// +stateify savable type SendErrors struct { tcpip.SendErrors @@ -257,6 +260,8 @@ type SendErrors struct { } // Stats holds statistics about the endpoint. +// +// +stateify savable type Stats struct { // SegmentsReceived is the number of TCP segments received that // the transport layer successfully parsed. @@ -311,18 +316,6 @@ type rcvQueueInfo struct { rcvQueue segmentList `state:"wait"` } -// +stateify savable -type accepted struct { - // NB: this could be an endpointList, but ilist only permits endpoints to - // belong to one list at a time, and endpoints are already stored in the - // dispatcher's list. - endpoints list.List `state:".([]*endpoint)"` - - // cap is the maximum number of endpoints that can be in the accepted endpoint - // list. - cap int -} - // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -338,7 +331,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> Protects e.accepted. +// e.acceptMu -> Protects e.acceptQueue. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -502,10 +495,6 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` - // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state; this is only accessed atomically. - synRcvdCount int32 - // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 @@ -579,7 +568,7 @@ type endpoint struct { // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. // +checklocks:acceptMu - accepted accepted + acceptQueue acceptQueue // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -612,8 +601,7 @@ type endpoint struct { gso stack.GSO - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats Stats `state:"nosave"` + stats Stats // tcpLingerTimeout is the maximum amount of a time a socket // a socket stays in TIME_WAIT state before being marked @@ -910,7 +898,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if e.accepted.endpoints.Len() != 0 { + if e.acceptQueue.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -1093,20 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - acceptedCopy := e.accepted - e.accepted = accepted{} - e.acceptMu.Unlock() - - if acceptedCopy == (accepted{}) { - return + // Close any endpoints in SYN-RCVD state. + for n := range e.acceptQueue.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) } - - e.acceptCond.Broadcast() - + e.acceptQueue.pendingEndpoints = nil // Reset all connections that are waiting to be accepted. - for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() { n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } + e.acceptQueue.endpoints.Init() + e.acceptMu.Unlock() + + e.acceptCond.Broadcast() + // Wait for reset of all endpoints that are still waiting to be delivered to // the now closed accepted. e.pendingAccepted.Wait() @@ -2498,22 +2486,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.accepted == (accepted{}) { - // listen is called after shutdown. - e.accepted.cap = backlog - e.shutdownFlags = 0 - e.rcvQueueInfo.rcvQueueMu.Lock() - e.rcvQueueInfo.RcvClosed = false - e.rcvQueueInfo.rcvQueueMu.Unlock() - } else { - // Adjust the size of the backlog iff we can fit - // existing pending connections into the new one. - if e.accepted.endpoints.Len() > backlog { - return &tcpip.ErrInvalidEndpointState{} - } - e.accepted.cap = backlog + + // Adjust the size of the backlog iff we can fit + // existing pending connections into the new one. + if e.acceptQueue.endpoints.Len() > backlog { + return &tcpip.ErrInvalidEndpointState{} + } + e.acceptQueue.capacity = backlog + + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) } + e.shutdownFlags = 0 + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() + // Notify any blocked goroutines that they can attempt to // deliver endpoints again. e.acceptCond.Broadcast() @@ -2548,8 +2537,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.accepted == (accepted{}) { - e.accepted.cap = backlog + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + if e.acceptQueue.capacity == 0 { + e.acceptQueue.capacity = backlog } e.acceptMu.Unlock() @@ -2589,8 +2581,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. // Get the new accepted endpoint. var n *endpoint e.acceptMu.Lock() - if element := e.accepted.endpoints.Front(); element != nil { - n = e.accepted.endpoints.Remove(element).(*endpoint) + if element := e.acceptQueue.endpoints.Front(); element != nil { + n = e.acceptQueue.endpoints.Remove(element).(*endpoint) } e.acceptMu.Unlock() if n == nil { diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 381f4474d..94072a115 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() { } // saveEndpoints is invoked by stateify. -func (a *accepted) saveEndpoints() []*endpoint { +func (a *acceptQueue) saveEndpoints() []*endpoint { acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { acceptedEndpoints[i] = e.Value.(*endpoint) @@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint { } // loadEndpoints is invoked by stateify. -func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { +func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) { for _, ep := range acceptedEndpoints { a.endpoints.PushBack(ep) } @@ -252,7 +252,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() bind() e.acceptMu.Lock() - backlog := e.accepted.cap + backlog := e.acceptQueue.capacity e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) |