summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--g3doc/README.md2
-rw-r--r--pkg/context/context.go20
-rw-r--r--pkg/p9/file.go35
-rw-r--r--pkg/p9/handlers.go6
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/p9test.go1
-rw-r--r--pkg/p9/server.go9
-rw-r--r--pkg/sentry/kernel/auth/context.go13
-rw-r--r--pkg/sentry/kernel/mq/mq.go2
-rw-r--r--pkg/sentry/kernel/shm/shm.go4
-rw-r--r--pkg/sentry/kernel/task_context.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go4
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go175
-rw-r--r--pkg/tcpip/stack/iptables_test.go220
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go270
-rw-r--r--runsc/boot/network.go11
-rw-r--r--runsc/fsgofer/fsgofer.go11
-rw-r--r--runsc/sandbox/network.go18
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/BUILD15
-rw-r--r--test/syscalls/linux/deleted.cc116
-rw-r--r--test/syscalls/linux/mount.cc34
-rw-r--r--test/util/fs_util.cc8
-rw-r--r--test/util/fs_util.h1
25 files changed, 753 insertions, 231 deletions
diff --git a/g3doc/README.md b/g3doc/README.md
index dc4179037..5e23aa5ec 100644
--- a/g3doc/README.md
+++ b/g3doc/README.md
@@ -23,7 +23,7 @@ links below to see detailed instructions for each of them:
gVisor provides a virtualized environment in order to sandbox containers. The
system interfaces normally implemented by the host kernel are moved into a
-distinct, per-sandbox application kernel in order to minimize the risk of an
+distinct, per-sandbox application kernel in order to minimize the risk of a
container escape exploit. gVisor does not introduce large fixed overheads
however, and still retains a process-like model with respect to resource
utilization.
diff --git a/pkg/context/context.go b/pkg/context/context.go
index f3031fc60..e86c14195 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -29,26 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
)
-type contextID int
-
-// Globally accessible values from a context. These keys are defined in the
-// context package to resolve dependency cycles by not requiring the caller to
-// import packages usually required to get these information.
-const (
- // CtxThreadGroupID is the current thread group ID when a context represents
- // a task context. The value is represented as an int32.
- CtxThreadGroupID contextID = iota
-)
-
-// ThreadGroupIDFromContext returns the current thread group ID when ctx
-// represents a task context.
-func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
- if tgid := ctx.Value(CtxThreadGroupID); tgid != nil {
- return tgid.(int32), true
- }
- return 0, false
-}
-
// A Context represents a thread of execution (hereafter "goroutine" to reflect
// Go idiosyncrasy). It carries state associated with the goroutine across API
// boundaries.
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 8d6af2d6b..b4b556cb9 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -21,13 +21,37 @@ import (
"gvisor.dev/gvisor/pkg/fd"
)
+// AttacherOptions contains Attacher configuration.
+type AttacherOptions struct {
+ // SetAttrOnDeleted is set to true if it's safe to call File.SetAttr for
+ // deleted files.
+ SetAttrOnDeleted bool
+
+ // AllocateOnDeleted is set to true if it's safe to call File.Allocate for
+ // deleted files.
+ AllocateOnDeleted bool
+}
+
+// NoServerOptions partially implements Attacher with empty AttacherOptions.
+type NoServerOptions struct{}
+
+// ServerOptions implements Attacher.
+func (*NoServerOptions) ServerOptions() AttacherOptions {
+ return AttacherOptions{}
+}
+
// Attacher is provided by the server.
type Attacher interface {
// Attach returns a new File.
//
- // The client-side attach will be translate to a series of walks from
+ // The client-side attach will be translated to a series of walks from
// the file returned by this Attach call.
Attach() (File, error)
+
+ // ServerOptions returns configuration options for this attach point.
+ //
+ // This is never caller in the client-side.
+ ServerOptions() AttacherOptions
}
// File is a set of operations corresponding to a single node.
@@ -301,7 +325,7 @@ type File interface {
type DefaultWalkGetAttr struct{}
// WalkGetAttr implements File.WalkGetAttr.
-func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
+func (*DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) {
return nil, nil, AttrMask{}, Attr{}, unix.ENOSYS
}
@@ -309,7 +333,7 @@ func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, er
type DisallowClientCalls struct{}
// SetAttrClose implements File.SetAttrClose.
-func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error {
+func (*DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error {
panic("SetAttrClose should not be called on the server")
}
@@ -321,6 +345,11 @@ func (*DisallowServerCalls) Renamed(File, string) {
panic("Renamed should not be called on the client")
}
+// ServerOptions implements Attacher.
+func (*DisallowServerCalls) ServerOptions() AttacherOptions {
+ panic("ServerOptions should not be called on the client")
+}
+
// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File.
func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
stats := make([]FullStat, 0, len(names))
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 2657081e3..c85af5e9e 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -178,7 +178,7 @@ func (t *Tsetattrclunk) handle(cs *connState) message {
// This might be technically incorrect, as it's possible that
// there were multiple links and you can still change the
// corresponding inode information.
- if ref.isDeleted() {
+ if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() {
return unix.EINVAL
}
@@ -913,7 +913,7 @@ func (t *Tsetattr) handle(cs *connState) message {
// This might be technically incorrect, as it's possible that
// there were multiple links and you can still change the
// corresponding inode information.
- if ref.isDeleted() {
+ if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() {
return unix.EINVAL
}
@@ -946,7 +946,7 @@ func (t *Tallocate) handle(cs *connState) message {
}
// We don't allow allocate on files that have been deleted.
- if ref.isDeleted() {
+ if !cs.server.options.AllocateOnDeleted && ref.isDeleted() {
return unix.EINVAL
}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 9c1ada0cb..f3eb8468b 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -12,7 +12,7 @@ MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9"
# mockgen_reflect is a source file that contains mock generation code that
# imports the p9 package and generates a specification via reflection. The
# usual generation path must be split into two distinct parts because the full
-# source tree is not available to all build targets. Only declared depencies
+# source tree is not available to all build targets. Only declared dependencies
# are available (and even then, not the Go source files).
genrule(
name = "mockgen_reflect",
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index fd5ac3dbe..56939d100 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -307,6 +307,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}
// Start the server, synchronized on exit.
+ h.Attacher.EXPECT().ServerOptions().Return(p9.AttacherOptions{}).Times(1)
server := p9.NewServer(h.Attacher)
h.wg.Add(1)
go func() {
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 6428ad745..e7d129f9d 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -34,6 +34,8 @@ type Server struct {
// attacher provides the attach function.
attacher Attacher
+ options AttacherOptions
+
// pathTree is the full set of paths opened on this server.
//
// These may be across different connections, but rename operations
@@ -48,10 +50,15 @@ type Server struct {
renameMu sync.RWMutex
}
-// NewServer returns a new server.
+// NewServer returns a new server. attacher may be nil.
func NewServer(attacher Attacher) *Server {
+ opts := AttacherOptions{}
+ if attacher != nil {
+ opts = attacher.ServerOptions()
+ }
return &Server{
attacher: attacher,
+ options: opts,
pathTree: newPathNode(),
}
}
diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go
index c08d47787..2039a96ad 100644
--- a/pkg/sentry/kernel/auth/context.go
+++ b/pkg/sentry/kernel/auth/context.go
@@ -24,6 +24,10 @@ type contextID int
const (
// CtxCredentials is a Context.Value key for Credentials.
CtxCredentials contextID = iota
+
+ // CtxThreadGroupID is the current thread group ID when a context represents
+ // a task context. The value is represented as an int32.
+ CtxThreadGroupID contextID = iota
)
// CredentialsFromContext returns a copy of the Credentials used by ctx, or a
@@ -35,6 +39,15 @@ func CredentialsFromContext(ctx context.Context) *Credentials {
return NewAnonymousCredentials()
}
+// ThreadGroupIDFromContext returns the current thread group ID when ctx
+// represents a task context.
+func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) {
+ if tgid := ctx.Value(CtxThreadGroupID); tgid != nil {
+ return tgid.(int32), true
+ }
+ return 0, false
+}
+
// ContextWithCredentials returns a copy of ctx carrying creds.
func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context {
return &authContext{ctx, creds}
diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go
index 07482decf..7515a2772 100644
--- a/pkg/sentry/kernel/mq/mq.go
+++ b/pkg/sentry/kernel/mq/mq.go
@@ -399,7 +399,7 @@ func (q *Queue) Flush(ctx context.Context) {
q.mu.Lock()
defer q.mu.Unlock()
- pid, ok := context.ThreadGroupIDFromContext(ctx)
+ pid, ok := auth.ThreadGroupIDFromContext(ctx)
if ok {
if q.subscriber != nil && pid == q.subscriber.pid {
q.subscriber = nil
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index ab938fa3c..bb9a129ab 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -444,7 +444,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch.
s.mu.Lock()
defer s.mu.Unlock()
s.attachTime = ktime.NowFromContext(ctx)
- if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok {
s.lastAttachDetachPID = pid
} else {
// AddMapping is called during a syscall, so ctx should always be a task
@@ -468,7 +468,7 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostar
// If called from a non-task context we also won't have a threadgroup
// id. Silently skip updating the lastAttachDetachPid in that case.
- if pid, ok := context.ThreadGroupIDFromContext(ctx); ok {
+ if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok {
s.lastAttachDetachPID = pid
} else {
log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked())
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index cb9bcd7c0..ce38d9342 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -86,7 +86,7 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} {
return t
case auth.CtxCredentials:
return t.creds.Load()
- case context.CtxThreadGroupID:
+ case auth.CtxThreadGroupID:
return int32(t.tg.ID())
case fs.CtxRoot:
if !isTaskGoroutine {
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index 4d73d46ef..fd0ab4c76 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -136,14 +136,14 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if err != nil {
return 0, nil, err
}
- tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.UMOUNT_NOFOLLOW == 0))
if err != nil {
return 0, nil, err
}
defer tpop.Release(t)
opts := vfs.UmountOptions{
- Flags: uint32(flags),
+ Flags: uint32(flags &^ linux.UMOUNT_NOFOLLOW),
}
return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index ead36880f..5d76adac1 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -134,6 +134,7 @@ go_test(
srcs = [
"conntrack_test.go",
"forwarding_test.go",
+ "iptables_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
"nic_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index c489506bb..1c6060b70 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -119,22 +119,24 @@ type conn struct {
//
// +checklocks:mu
destinationManip bool
+
+ stateMu sync.RWMutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
// of tcp connection.
//
- // +checklocks:mu
+ // +checklocks:stateMu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection.
//
- // +checklocks:mu
+ // +checklocks:stateMu
lastUsed tcpip.MonotonicTime
}
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
- cn.mu.RLock()
- defer cn.mu.RUnlock()
+ cn.stateMu.RLock()
+ defer cn.stateMu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -147,7 +149,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
// update the connection tracking state.
//
-// +checklocks:cn.mu
+// +checklocks:cn.stateMu
func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
return
@@ -209,17 +211,41 @@ type bucket struct {
tuples tupleList
}
-func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) {
- switch pkt.tuple.id().transProto {
+// A netAndTransHeadersFunc returns the network and transport headers found
+// in an ICMP payload. The transport layer's payload will not be returned.
+//
+// May panic if the packet does not hold the transport header.
+type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
+
+func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv4(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the buffer is smaller than
+ // the total length specified in the IPv4 header.
+ transHdr := icmpPayload[netHdr.HeaderLength():]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv6(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the IP payload is smaller than
+ // the payload length specified in the IPv6 header.
+ transHdr := icmpPayload[header.IPv6MinimumSize:]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
+ switch transProto {
case header.TCPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.TCP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
+ return netHeader, header.TCP(transHeaderBytes), true
}
case header.UDPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.UDP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
+ return netHeader, header.UDP(transHeaderBytes), true
}
}
return nil, nil, false
@@ -246,7 +272,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic("should have dropped packets with IPv4 options")
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
return netHdr, transHdr, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -264,7 +290,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
return netHdr, transHdr, true, true
}
}
@@ -283,34 +309,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro
}
}
-func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
- switch transProto {
- case header.TCPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.TCP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
- case header.UDPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.UDP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
+func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }, true
}
return tupleID{}, false
@@ -349,7 +357,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
return tupleID{}, false, false
}
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
return tid, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -370,7 +378,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
}
// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
return tid, true, true
}
}
@@ -601,14 +609,17 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
// packets are fragmented.
reply := pkt.tuple.reply
- tid, performManip := func() (tupleID, bool) {
- cn.mu.Lock()
- defer cn.mu.Unlock()
- // Mark the connection as having been used recently so it isn't reaped.
- cn.lastUsed = cn.ct.clock.NowMonotonic()
- // Update connection state.
- cn.updateLocked(pkt, reply)
+ cn.stateMu.Lock()
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = cn.ct.clock.NowMonotonic()
+ // Update connection state.
+ cn.updateLocked(pkt, reply)
+ cn.stateMu.Unlock()
+
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
var tuple *tuple
if reply {
@@ -730,9 +741,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused deletes timed out entries from the conntrack map. The rules for
// reaping are:
-// - Most reaping occurs in connFor, which is called on each packet. connFor
-// cleans up the bucket the packet's connection maps to. Thus calls to
-// reapUnused should be fast.
// - Each call to reapUnused traverses a fraction of the conntrack table.
// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
// - After reaping, reapUnused decides when it should next run based on the
@@ -799,45 +807,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// Precondition: ct.mu is read locked and bkt.mu is write locked.
// +checklocksread:ct.mu
// +checklocks:bkt.mu
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
- if !tuple.conn.timedOut(now) {
+func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
+ if !reapingTuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap both tuples if the reply appears
- // later in the table.
- replyBktID := ct.bucket(tuple.id().reply())
- tuple.conn.mu.RLock()
- replyTupleInserted := tuple.conn.finalized
- tuple.conn.mu.RUnlock()
- if bktID > replyBktID && replyTupleInserted {
- return true
+ var otherTuple *tuple
+ if reapingTuple.reply {
+ otherTuple = &reapingTuple.conn.original
+ } else {
+ otherTuple = &reapingTuple.conn.reply
}
- // Reap the reply.
- if replyTupleInserted {
- // Don't re-lock if both tuples are in the same bucket.
- if bktID != replyBktID {
- replyBkt := &ct.buckets[replyBktID]
- replyBkt.mu.Lock()
- removeConnFromBucket(replyBkt, tuple)
- replyBkt.mu.Unlock()
- } else {
- removeConnFromBucket(bkt, tuple)
- }
+ otherTupleBktID := ct.bucket(otherTuple.id())
+ reapingTuple.conn.mu.RLock()
+ replyTupleInserted := reapingTuple.conn.finalized
+ reapingTuple.conn.mu.RUnlock()
+
+ // To maintain lock order, we can only reap both tuples if the tuple for the
+ // other direction appears later in the table.
+ if bktID > otherTupleBktID && replyTupleInserted {
+ return true
}
- bkt.tuples.Remove(tuple)
- return true
-}
+ bkt.tuples.Remove(reapingTuple)
+
+ if !replyTupleInserted {
+ // The other tuple is the reply which has not yet been inserted.
+ return true
+ }
-// +checklocks:b.mu
-func removeConnFromBucket(b *bucket, tuple *tuple) {
- if tuple.reply {
- b.tuples.Remove(&tuple.conn.original)
+ // Reap the other connection.
+ if bktID == otherTupleBktID {
+ // Don't re-lock if both tuples are in the same bucket.
+ bkt.tuples.Remove(otherTuple)
} else {
- b.tuples.Remove(&tuple.conn.reply)
+ otherTupleBkt := &ct.buckets[otherTupleBktID]
+ otherTupleBkt.mu.Lock()
+ otherTupleBkt.tuples.Remove(otherTuple)
+ otherTupleBkt.mu.Unlock()
}
+
+ return true
}
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go
new file mode 100644
index 000000000..1788e98c9
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestNATedConnectionReap tests that NATed connections are properly reaped.
+func TestNATedConnectionReap(t *testing.T) {
+ // Note that the network protocol used for this test doesn't matter as this
+ // test focuses on reaping, not anything related to a specific network
+ // protocol.
+
+ const (
+ nattedDstPort = 1
+ srcPort = 2
+ dstPort = 3
+
+ nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ )
+
+ clock := faketime.NewManualClock()
+ iptables := DefaultTables(0 /* seed */, clock)
+
+ table := Table{
+ Rules: []Rule{
+ // Prerouting
+ {
+ Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort},
+ },
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Target: &AcceptTarget{},
+ },
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 2,
+ Forward: 3,
+ Output: 4,
+ Postrouting: 5,
+ },
+ }
+ if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err)
+ }
+
+ // Stop the reaper if it is running so we can reap manually as it is started
+ // on the first change to IPTables.
+ iptables.reaperDone <- struct{}{}
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize,
+ })
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udp.SetSourcePort(srcPort)
+ udp.SetDestinationPort(dstPort)
+ udp.SetChecksum(0)
+ udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
+ header.UDPProtocolNumber,
+ srcAddr,
+ dstAddr,
+ uint16(len(udp)),
+ )))
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(udp)),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 64,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ })
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ originalTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get original tuple ID")
+ }
+
+ if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckPrerouting(...) = false, want = true")
+ }
+ if !iptables.CheckInput(pkt, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckInput(...) = false, want = true")
+ }
+
+ invertedReplyTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get NATed packet's tuple ID")
+ }
+ if invertedReplyTID == originalTID {
+ t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID)
+ }
+ replyTID := invertedReplyTID.reply()
+
+ originalBktID := iptables.connections.bucket(originalTID)
+ replyBktID := iptables.connections.bucket(replyTID)
+
+ // This test depends on the original and reply tuples mapping to different
+ // buckets.
+ if originalBktID == replyBktID {
+ t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID)
+ }
+
+ lowerBktID := originalBktID
+ if lowerBktID > replyBktID {
+ lowerBktID = replyBktID
+ }
+
+ runReaper := func() {
+ // Reaping the bucket with the lower ID should reap both tuples of the
+ // connection if it has timed out.
+ //
+ // We will manually pick the next start bucket ID and don't use the
+ // interval so we ignore the return values.
+ _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */)
+ }
+
+ iptables.connections.mu.RLock()
+ buckets := iptables.connections.buckets
+ iptables.connections.mu.RUnlock()
+
+ originalBkt := &buckets[originalBktID]
+ replyBkt := &buckets[replyBktID]
+
+ // Run the reaper and make sure the tuples were not reaped.
+ reapAndCheckForConnections := func() {
+ t.Helper()
+
+ runReaper()
+
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil {
+ t.Error("expected to get original tuple")
+ }
+
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil {
+ t.Error("expected to get reply tuple")
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ // Connection was just added and no time has passed - it should not be reaped.
+ reapAndCheckForConnections()
+
+ // Time must advance past the unestablished timeout for a connection to be
+ // reaped.
+ clock.Advance(unestablishedTimeout)
+ reapAndCheckForConnections()
+
+ // Connection should now be reaped.
+ clock.Advance(1)
+ runReaper()
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil {
+ t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple)
+ }
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil {
+ t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple)
+ }
+ // Make sure we don't have stale tuples just lying around.
+ //
+ // We manually check the buckets as connForTID will skip over tuples that
+ // have timed out.
+ checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) {
+ t.Helper()
+
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ if tuple.id() == originalTID {
+ t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply)
+ }
+ }
+ }
+ checkNoTupleInBucket(originalBkt, originalTID, false /* reply */)
+ checkNoTupleInBucket(replyBkt, replyTID, true /* reply */)
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 7fe3b29d9..b2383576c 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) {
}
func TestNATICMPError(t *testing.T) {
- const srcPort = 1234
- const dstPort = 5432
+ const (
+ srcPort = 1234
+ dstPort = 5432
+ dataSize = 4
+ )
type icmpTypeTest struct {
name string
@@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) {
netProto: ipv4.ProtocolNumber,
host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)
- hdr := buffer.NewPrependable(totalLen)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
}
@@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) {
icmp.SetType(header.ICMPv4Type(icmpType))
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.ICMPv4ProtocolNumber,
utils.Host1IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
@@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.UDPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.UDPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.TCPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.TCPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) {
Src: utils.Host1IPv6Addr.AddressWithPrefix.Address,
Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
}))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
payloadLen,
header.ICMPv6ProtocolNumber,
utils.Host1IPv6Addr.AddressWithPrefix.Address,
@@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.UDPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(udp),
header.UDPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.TCPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(tcp),
header.TCPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) {
},
}
+ trimTests := []struct {
+ name string
+ trimLen int
+ expectNATedICMP bool
+ }{
+ {
+ name: "Trim nothing",
+ trimLen: 0,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data",
+ trimLen: dataSize,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data and transport header",
+ trimLen: dataSize + 1,
+ expectNATedICMP: false,
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, transportType := range test.transportTypes {
t.Run(transportType.name, func(t *testing.T) {
for _, icmpType := range test.icmpTypes {
t.Run(icmpType.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
-
- ep1 := channel.New(1, header.IPv6MinimumMTU, "")
- ep2 := channel.New(1, header.IPv6MinimumMTU, "")
- utils.SetupRouterStack(t, s, ep1, ep2)
-
- ipv6 := test.netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
+ for _, trimTest := range trimTests {
+ t.Run(trimTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+
+ ep1 := channel.New(1, header.IPv6MinimumMTU, "")
+ ep2 := channel.New(1, header.IPv6MinimumMTU, "")
+ utils.SetupRouterStack(t, s, ep1, ep2)
+
+ ipv6 := test.netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
},
- Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
},
- Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
+ }
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
- ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(),
- }))
+ buf := transportType.buf
- {
- pkt, ok := ep1.Read()
- if !ok {
- t.Fatal("expected to read a packet on ep1")
- }
- pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
- transportType.checkNATed(t, pktView)
- if t.Failed() {
- t.FailNow()
- }
+ ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: append(buffer.View(nil), buf...).ToVectorisedView(),
+ }))
- ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
- }))
- }
+ {
+ pkt, ok := ep1.Read()
+ if !ok {
+ t.Fatal("expected to read a packet on ep1")
+ }
+ pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
+ transportType.checkNATed(t, pktView)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ pktView = pktView[:len(pktView)-trimTest.trimLen]
+ buf = buf[:len(buf)-trimTest.trimLen]
+
+ ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
+ }))
+ }
- pkt, ok := ep2.Read()
- if ok != icmpType.expectResponse {
- t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse)
- }
- if !icmpType.expectResponse {
- return
+ pkt, ok := ep2.Read()
+ expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
+ if ok != expectResponse {
+ t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse)
+ }
+ if !expectResponse {
+ return
+ }
+ test.decrementTTL(buf)
+ test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val)
+ })
}
- test.decrementTTL(transportType.buf)
- test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val)
})
}
})
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 9fb3ebd95..f819cf8fb 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -78,6 +78,11 @@ type DefaultRoute struct {
Name string
}
+type Neighbor struct {
+ IP net.IP
+ HardwareAddr net.HardwareAddr
+}
+
// FDBasedLink configures an fd-based link.
type FDBasedLink struct {
Name string
@@ -90,6 +95,7 @@ type FDBasedLink struct {
RXChecksumOffload bool
LinkAddress net.HardwareAddr
QDisc config.QueueingDiscipline
+ Neighbors []Neighbor
// NumChannels controls how many underlying FD's are to be used to
// create this endpoint.
@@ -241,6 +247,11 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
routes = append(routes, route)
}
+
+ for _, neigh := range link.Neighbors {
+ proto, tcpipAddr := ipToAddressAndProto(neigh.IP)
+ n.Stack.AddStaticNeighbor(nicID, proto, tcpipAddr, tcpip.LinkAddress(neigh.HardwareAddr))
+ }
}
if !args.Defaultv4Gateway.Route.Empty() {
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 600b21189..3d610199c 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -140,6 +140,17 @@ func (a *attachPoint) Attach() (p9.File, error) {
return lf, nil
}
+// ServerOptions implements p9.Attacher. It's safe to call SetAttr and Allocate
+// on deleted files because fsgofer either uses an existing FD or opens a new
+// one using the magic symlink in `/proc/[pid]/fd` and cannot mistakely open
+// a file that was created in the same path as the delete file.
+func (a *attachPoint) ServerOptions() p9.AttacherOptions {
+ return p9.AttacherOptions{
+ SetAttrOnDeleted: true,
+ AllocateOnDeleted: true,
+ }
+}
+
// makeQID returns a unique QID for the given stat buffer.
func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID {
a.deviceMu.Lock()
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 3451d1037..03c5de2c6 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -173,6 +173,23 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
continue
}
+ // Collect data from the ARP table.
+ dump, err := netlink.NeighList(iface.Index, 0)
+ if err != nil {
+ return fmt.Errorf("fetching ARP table for %q: %w", iface.Name, err)
+ }
+
+ var neighbors []boot.Neighbor
+ for _, n := range dump {
+ // There are only two "good" states NUD_PERMANENT and NUD_REACHABLE,
+ // but NUD_REACHABLE is fully dynamic and will be re-probed anyway.
+ if n.State == netlink.NUD_PERMANENT {
+ log.Debugf("Copying a static ARP entry: %+v %+v", n.IP, n.HardwareAddr)
+ // No flags are copied because Stack.AddStaticNeighbor does not support flags right now.
+ neighbors = append(neighbors, boot.Neighbor{IP: n.IP, HardwareAddr: n.HardwareAddr})
+ }
+ }
+
// Scrape the routes before removing the address, since that
// will remove the routes as well.
routes, defv4, defv6, err := routesForIface(iface)
@@ -203,6 +220,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
RXChecksumOffload: rxChecksumOffload,
NumChannels: numNetworkChannels,
QDisc: qDisc,
+ Neighbors: neighbors,
}
// Get the link for the interface.
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index f748d685a..7952fd969 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1053,3 +1053,7 @@ syscall_test(
syscall_test(
test = "//test/syscalls/linux:verity_mount_test",
)
+
+syscall_test(
+ test = "//test/syscalls/linux:deleted_test",
+)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 6217ff4dc..020c4673a 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -4432,3 +4432,18 @@ cc_binary(
"@com_google_absl//absl/container:flat_hash_set",
],
)
+
+cc_binary(
+ name = "deleted_test",
+ testonly = 1,
+ srcs = ["deleted.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ gtest,
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/syscalls/linux/deleted.cc b/test/syscalls/linux/deleted.cc
new file mode 100644
index 000000000..695ceafd3
--- /dev/null
+++ b/test/syscalls/linux/deleted.cc
@@ -0,0 +1,116 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+constexpr mode_t mode = 1;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<FileDescriptor> createdDeleted() {
+ auto path = NewTempAbsPath();
+ PosixErrorOr<FileDescriptor> fd = Open(path, O_RDWR | O_CREAT, mode);
+ if (!fd.ok()) {
+ return fd.error();
+ }
+
+ auto err = Unlink(path);
+ if (!err.ok()) {
+ return err;
+ }
+ return fd;
+}
+
+TEST(DeletedTest, Utime) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted());
+
+ const struct timespec times[2] = {{10, 0}, {20, 0}};
+ EXPECT_THAT(futimens(fd.get(), times), SyscallSucceeds());
+
+ struct stat stat;
+ ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(10, stat.st_atime);
+ EXPECT_EQ(20, stat.st_mtime);
+}
+
+TEST(DeletedTest, Chmod) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted());
+
+ ASSERT_THAT(fchmod(fd.get(), mode + 1), SyscallSucceeds());
+
+ struct stat stat;
+ ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(mode + 1, stat.st_mode & ~S_IFMT);
+}
+
+TEST(DeletedTest, Truncate) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted());
+ const std::string data = "foobar";
+ ASSERT_THAT(write(fd.get(), data.c_str(), data.size()), SyscallSucceeds());
+
+ ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds());
+
+ struct stat stat;
+ ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds());
+ ASSERT_EQ(stat.st_size, 0);
+}
+
+TEST(DeletedTest, Fallocate) {
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted());
+
+ ASSERT_THAT(fallocate(fd.get(), 0, 0, 123), SyscallSucceeds());
+
+ struct stat stat;
+ ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds());
+ EXPECT_EQ(123, stat.st_size);
+}
+
+// Tests that a file can be created with the same path as a deleted file that
+// still have an open FD to it.
+TEST(DeletedTest, Replace) {
+ auto path = NewTempAbsPath();
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, mode));
+ ASSERT_NO_ERRNO(Unlink(path));
+
+ auto other =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT | O_EXCL, mode));
+
+ auto stat = ASSERT_NO_ERRNO_AND_VALUE(Fstat(fd.get()));
+ auto stat_other = ASSERT_NO_ERRNO_AND_VALUE(Fstat(other.get()));
+ ASSERT_NE(stat.st_ino, stat_other.st_ino);
+
+ // Check that the path points to the new file.
+ stat = ASSERT_NO_ERRNO_AND_VALUE(Stat(path));
+ ASSERT_EQ(stat.st_ino, stat_other.st_ino);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index 3c7311782..e2a41d172 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -115,6 +115,40 @@ TEST(MountTest, OpenFileBusy) {
EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
}
+TEST(MountTest, UmountNoFollow) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ auto const mountPoint = NewTempAbsPathInDir(dir.path());
+ ASSERT_THAT(mkdir(mountPoint.c_str(), 0777), SyscallSucceeds());
+
+ // Create a symlink in dir which will point to the actual mountpoint.
+ const std::string symlinkInDir = NewTempAbsPathInDir(dir.path());
+ EXPECT_THAT(symlink(mountPoint.c_str(), symlinkInDir.c_str()),
+ SyscallSucceeds());
+
+ // Create a symlink to the dir.
+ const std::string symlinkToDir = NewTempAbsPath();
+ EXPECT_THAT(symlink(dir.path().c_str(), symlinkToDir.c_str()),
+ SyscallSucceeds());
+
+ // Should fail with ELOOP when UMOUNT_NOFOLLOW is specified and the last
+ // component is a symlink.
+ auto mount = ASSERT_NO_ERRNO_AND_VALUE(
+ Mount("", mountPoint, "tmpfs", 0, "mode=0700", 0));
+ EXPECT_THAT(umount2(symlinkInDir.c_str(), UMOUNT_NOFOLLOW),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(unlink(symlinkInDir.c_str()), SyscallSucceeds());
+
+ // UMOUNT_NOFOLLOW should only apply to the last path component. A symlink in
+ // non-last path component should be just fine.
+ EXPECT_THAT(umount2(JoinPath(symlinkToDir, Basename(mountPoint)).c_str(),
+ UMOUNT_NOFOLLOW),
+ SyscallSucceeds());
+ mount.Release();
+}
+
TEST(MountTest, UmountDetach) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index 253411858..1c24d9ffc 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -188,6 +188,14 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
return NoError();
}
+PosixError Unlink(absl::string_view path) {
+ int res = unlink(std::string(path).c_str());
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("unlink ", path));
+ }
+ return NoError();
+}
+
PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
int flags) {
int res = unlinkat(dfd.get(), std::string(path).c_str(), flags);
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index bb2d1d3c8..3ae0a725a 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -71,6 +71,7 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
dev_t dev);
// Unlink the file.
+PosixError Unlink(absl::string_view path);
PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
int flags);