summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/ilist/list.go33
-rw-r--r--pkg/safemem/seq_test.go21
-rw-r--r--pkg/safemem/seq_unsafe.go3
-rw-r--r--pkg/seccomp/seccomp_test.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go9
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go80
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go11
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go65
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_x86.go)2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go6
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/connect.go11
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go3
21 files changed, 225 insertions, 143 deletions
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 019caadca..f3a609b57 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -88,8 +88,9 @@ func (l *List) Back() Element {
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(l.head)
- ElementMapper{}.linkerFor(e).SetPrev(nil)
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
@@ -102,8 +103,9 @@ func (l *List) PushFront(e Element) {
// PushBack inserts the element e at the back of list l.
func (l *List) PushBack(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(nil)
- ElementMapper{}.linkerFor(e).SetPrev(l.tail)
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
@@ -132,10 +134,14 @@ func (l *List) PushBackList(m *List) {
// InsertAfter inserts e after b.
func (l *List) InsertAfter(b, e Element) {
- a := ElementMapper{}.linkerFor(b).Next()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(b).SetNext(e)
+ bLinker := ElementMapper{}.linkerFor(b)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
if a != nil {
ElementMapper{}.linkerFor(a).SetPrev(e)
@@ -146,10 +152,13 @@ func (l *List) InsertAfter(b, e Element) {
// InsertBefore inserts e before a.
func (l *List) InsertBefore(a, e Element) {
- b := ElementMapper{}.linkerFor(a).Prev()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(a).SetPrev(e)
+ aLinker := ElementMapper{}.linkerFor(a)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
if b != nil {
ElementMapper{}.linkerFor(b).SetNext(e)
diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go
index eba4bb535..de34005e9 100644
--- a/pkg/safemem/seq_test.go
+++ b/pkg/safemem/seq_test.go
@@ -20,6 +20,27 @@ import (
"testing"
)
+func TestBlockSeqOfEmptyBlock(t *testing.T) {
+ bs := BlockSeqOf(Block{})
+ if !bs.IsEmpty() {
+ t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs)
+ }
+}
+
+func TestBlockSeqOfNonemptyBlock(t *testing.T) {
+ b := BlockFromSafeSlice(make([]byte, 1))
+ bs := BlockSeqOf(b)
+ if bs.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs)
+ }
+ if head := bs.Head(); head != b {
+ t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b)
+ }
+ if tail := bs.Tail(); !tail.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail)
+ }
+}
+
type blockSeqTest struct {
desc string
diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
index dcdfc9600..f5f0574f8 100644
--- a/pkg/safemem/seq_unsafe.go
+++ b/pkg/safemem/seq_unsafe.go
@@ -56,6 +56,9 @@ type BlockSeq struct {
// BlockSeqOf returns a BlockSeq representing the single Block b.
func BlockSeqOf(b Block) BlockSeq {
+ if b.length == 0 {
+ return BlockSeq{}
+ }
bs := BlockSeq{
data: b.start,
length: -1,
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index da5a5e4b2..88766f33b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -451,7 +451,7 @@ func TestRandom(t *testing.T) {
}
}
- fmt.Printf("Testing filters: %v", syscallRules)
+ t.Logf("Testing filters: %v", syscallRules)
instrs, err := BuildProgram([]RuleSet{
RuleSet{
Rules: syscallRules,
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index daecc4ffe..1922ff08c 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui
// RemoveXattr implements fs.InodeOperations.RemoveXattr.
func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
- i.mu.RLock()
- defer i.mu.RUnlock()
+ i.mu.Lock()
+ defer i.mu.Unlock()
if _, ok := i.xattrs[name]; ok {
delete(i.xattrs, name)
return nil
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 54c1031a7..e95209661 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -361,8 +361,15 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro
rw.d.handleMu.RLock()
if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off)
- rw.d.handleMu.RUnlock()
rw.off += n
+ rw.d.dataMu.Lock()
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
return n, err
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index e99798c56..cd74945e7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,7 @@ import (
"strings"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -183,13 +184,76 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
- return append(rules, seccomp.RuleSet{
- Rules: seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ rules = append(rules,
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
},
- },
- Action: linux.SECCOMP_RET_ALLOW,
- })
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ })
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules,
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ return rules
+}
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is dynamic
+// because kernels have be backported behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index 7b975137f..7f5c393f0 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -160,6 +160,15 @@ func enableCpuidFault() {
// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
// Ref attachedThread() for more detail.
-func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
return rules
}
+
+// probeSeccomp returns true if seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8.
+//
+// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so
+// probeSeccomp can always return true.
+func probeSeccomp() bool {
+ return true
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 74968dfdf..2ce528601 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -20,7 +20,6 @@ import (
"fmt"
"syscall"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -30,54 +29,6 @@ import (
const syscallEvent syscall.Signal = 0x80
-// probeSeccomp returns true iff seccomp is run after ptrace notifications,
-// which is generally the case for kernel version >= 4.8. This check is dynamic
-// because kernels have be backported behavior.
-//
-// See createStub for more information.
-//
-// Precondition: the runtime OS thread must be locked.
-func probeSeccomp() bool {
- // Create a completely new, destroyable process.
- t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
- if err != nil {
- panic(fmt.Sprintf("seccomp probe failed: %v", err))
- }
- defer t.destroy()
-
- // Set registers to the yield system call. This call is not allowed
- // by the filters specified in the attachThread function.
- regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
- if err := t.setRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace set regs failed: %v", err))
- }
-
- for {
- // Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
- // Did the seccomp errno hook already run? This would
- // indicate that seccomp is first in line and we're
- // less than 4.8.
- if err := t.getRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
- }
- if _, err := syscallReturnValue(&regs); err == nil {
- // The seccomp errno mode ran first, and reset
- // the error in the registers.
- return false
- }
- // The seccomp hook did not run yet, and therefore it
- // is safe to use RET_KILL mode for dispatched calls.
- return true
- }
- }
-}
-
// createStub creates a fresh stub processes.
//
// Precondition: the runtime OS thread must be locked.
@@ -123,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// stub and all its children. This is used to create child stubs
// (below), so we must include the ability to fork, but otherwise lock
// down available calls only to what is needed.
- rules := []seccomp.RuleSet{
- // Rules for trapping vsyscall access.
- {
- Rules: seccomp.SyscallRules{
- syscall.SYS_GETTIMEOFDAY: {},
- syscall.SYS_TIME: {},
- unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
- },
- Action: linux.SECCOMP_RET_TRAP,
- Vsyscall: true,
- },
- }
+ rules := []seccomp.RuleSet{}
if defaultAction != linux.SECCOMP_RET_ALLOW {
rules = append(rules, seccomp.RuleSet{
Rules: seccomp.SyscallRules{
@@ -173,9 +113,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
},
Action: linux.SECCOMP_RET_ALLOW,
})
-
- rules = appendArchSeccompRules(rules)
}
+ rules = appendArchSeccompRules(rules, defaultAction)
instrs, err := seccomp.BuildProgram(rules, defaultAction)
if err != nil {
return nil, err
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 4f2406ce3..581841555 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -80,7 +80,7 @@ go_library(
"pagetables_amd64.go",
"pagetables_arm64.go",
"pagetables_x86.go",
- "pcids_x86.go",
+ "pcids.go",
"walker_amd64.go",
"walker_arm64.go",
"walker_empty.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
index e199bae18..9206030bf 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
package pagetables
import (
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 48c268bfa..13a9a60b4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -712,6 +712,10 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
+ }
+
family := usermem.ByteOrder.Uint16(sockaddr)
var addr tcpip.FullAddress
@@ -2663,7 +2667,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 46d3a6646..3e6196aee 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
for _, r := range primaryAddrs {
// If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoing() {
+ if !r.isValidForOutgoingRLocked() {
continue
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index ebb6c5e3b..13354d884 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -551,11 +551,13 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
-// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
-// and returns the network protocol number to be used to communicate with the
-// specified address. It returns an error if the passed address is incompatible
-// with the receiver.
-func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Preconditon: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 426da1ee6..2a396e9bc 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- toCopy := *to
- to = &toCopy
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- // Find the enpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ // Find the endpoint.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index cd247f3e1..ae4f3f3a9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
ttl := h.ep.ttl
+ amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
@@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
@@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ h.ep.mu.RLock()
+ amss := h.ep.amss
+ h.ep.mu.RUnlock()
+
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
+ h.ep.mu.Lock()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
@@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
+ h.ep.mu.Unlock()
// Execute is also called in a listen context so we want to make sure we
// only send the TS/SACK option when we received the TS/SACK in the
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9e72730bd..40cc664c0 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
now := time.Now()
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
e.rcvAutoParams.copied += copied
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
prevRTTCopied := e.rcvAutoParams.copied + copied
@@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvAutoParams.measureTime = now
e.rcvAutoParams.copied = 0
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
}
// IPTables implements tcpip.Endpoint.IPTables.
@@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
-
e.mu.RUnlock()
if err == tcpip.ErrClosedForReceive {
@@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now
-// would be under aMSS or under half receive buffer, whichever smaller. This is
-// useful as a receive side silly window syndrome prevention mechanism. If
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
+ e.mu.RLock()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.notifyProtocolGoroutine(mask)
return nil
@@ -1868,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -1904,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
connectingAddr := addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2270,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.BindAddr = addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2414,13 +2421,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@@ -2428,7 +2436,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventIn)
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 958f03ac1..d80aff1b6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -195,6 +195,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
for i := first; i < len(r.pendingRcvdSegments); i++ {
r.pendingRcvdSegments[i].decRef()
+ // Note that slice truncation does not allow garbage collection of
+ // truncated items, thus truncated items must be set to nil to avoid
+ // memory leaks.
+ r.pendingRcvdSegments[i] = nil
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
index 9fd061d7d..e28f213ba 100644
--- a/pkg/tcpip/transport/tcp/segment_heap.go
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -41,6 +41,7 @@ func (h *segmentHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
+ old[n-1] = nil
*h = old[:n-1]
return x
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1c6a600b8..0af4514e1 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- r, _, err := e.connectRoute(nicID, *to, netProto)
+ r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
- dstPort = to.Port
+ dstPort = dst.Port
}
if route.IsResolutionRequired() {
@@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa)
+ fa, netProto, err := e.checkV4MappedLocked(fa)
if err != nil {
return err
}
@@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr)
- if err != nil {
- return err
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
@@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 43fb047ed..466bd9381 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.stack = s
for _, m := range e.multicastMemberships {