diff options
-rw-r--r-- | g3doc/README.md | 2 | ||||
-rw-r--r-- | pkg/context/context.go | 20 | ||||
-rw-r--r-- | pkg/p9/file.go | 35 | ||||
-rw-r--r-- | pkg/p9/handlers.go | 6 | ||||
-rw-r--r-- | pkg/p9/p9test/BUILD | 2 | ||||
-rw-r--r-- | pkg/p9/p9test/p9test.go | 1 | ||||
-rw-r--r-- | pkg/p9/server.go | 9 | ||||
-rw-r--r-- | pkg/sentry/kernel/auth/context.go | 13 | ||||
-rw-r--r-- | pkg/sentry/kernel/mq/mq.go | 2 | ||||
-rw-r--r-- | pkg/sentry/kernel/shm/shm.go | 4 | ||||
-rw-r--r-- | pkg/sentry/kernel/task_context.go | 2 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/mount.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 175 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_test.go | 220 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 270 | ||||
-rw-r--r-- | runsc/boot/network.go | 11 | ||||
-rw-r--r-- | runsc/fsgofer/fsgofer.go | 11 | ||||
-rw-r--r-- | runsc/sandbox/network.go | 18 | ||||
-rw-r--r-- | test/syscalls/BUILD | 4 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 15 | ||||
-rw-r--r-- | test/syscalls/linux/deleted.cc | 116 | ||||
-rw-r--r-- | test/syscalls/linux/mount.cc | 34 | ||||
-rw-r--r-- | test/util/fs_util.cc | 8 | ||||
-rw-r--r-- | test/util/fs_util.h | 1 |
25 files changed, 753 insertions, 231 deletions
diff --git a/g3doc/README.md b/g3doc/README.md index dc4179037..5e23aa5ec 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -23,7 +23,7 @@ links below to see detailed instructions for each of them: gVisor provides a virtualized environment in order to sandbox containers. The system interfaces normally implemented by the host kernel are moved into a -distinct, per-sandbox application kernel in order to minimize the risk of an +distinct, per-sandbox application kernel in order to minimize the risk of a container escape exploit. gVisor does not introduce large fixed overheads however, and still retains a process-like model with respect to resource utilization. diff --git a/pkg/context/context.go b/pkg/context/context.go index f3031fc60..e86c14195 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -29,26 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/log" ) -type contextID int - -// Globally accessible values from a context. These keys are defined in the -// context package to resolve dependency cycles by not requiring the caller to -// import packages usually required to get these information. -const ( - // CtxThreadGroupID is the current thread group ID when a context represents - // a task context. The value is represented as an int32. - CtxThreadGroupID contextID = iota -) - -// ThreadGroupIDFromContext returns the current thread group ID when ctx -// represents a task context. -func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { - if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { - return tgid.(int32), true - } - return 0, false -} - // A Context represents a thread of execution (hereafter "goroutine" to reflect // Go idiosyncrasy). It carries state associated with the goroutine across API // boundaries. diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 8d6af2d6b..b4b556cb9 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -21,13 +21,37 @@ import ( "gvisor.dev/gvisor/pkg/fd" ) +// AttacherOptions contains Attacher configuration. +type AttacherOptions struct { + // SetAttrOnDeleted is set to true if it's safe to call File.SetAttr for + // deleted files. + SetAttrOnDeleted bool + + // AllocateOnDeleted is set to true if it's safe to call File.Allocate for + // deleted files. + AllocateOnDeleted bool +} + +// NoServerOptions partially implements Attacher with empty AttacherOptions. +type NoServerOptions struct{} + +// ServerOptions implements Attacher. +func (*NoServerOptions) ServerOptions() AttacherOptions { + return AttacherOptions{} +} + // Attacher is provided by the server. type Attacher interface { // Attach returns a new File. // - // The client-side attach will be translate to a series of walks from + // The client-side attach will be translated to a series of walks from // the file returned by this Attach call. Attach() (File, error) + + // ServerOptions returns configuration options for this attach point. + // + // This is never caller in the client-side. + ServerOptions() AttacherOptions } // File is a set of operations corresponding to a single node. @@ -301,7 +325,7 @@ type File interface { type DefaultWalkGetAttr struct{} // WalkGetAttr implements File.WalkGetAttr. -func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { +func (*DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { return nil, nil, AttrMask{}, Attr{}, unix.ENOSYS } @@ -309,7 +333,7 @@ func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, er type DisallowClientCalls struct{} // SetAttrClose implements File.SetAttrClose. -func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { +func (*DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { panic("SetAttrClose should not be called on the server") } @@ -321,6 +345,11 @@ func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } +// ServerOptions implements Attacher. +func (*DisallowServerCalls) ServerOptions() AttacherOptions { + panic("ServerOptions should not be called on the client") +} + // DefaultMultiGetAttr implements File.MultiGetAttr() on top of File. func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { stats := make([]FullStat, 0, len(names)) diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 2657081e3..c85af5e9e 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -178,7 +178,7 @@ func (t *Tsetattrclunk) handle(cs *connState) message { // This might be technically incorrect, as it's possible that // there were multiple links and you can still change the // corresponding inode information. - if ref.isDeleted() { + if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() { return unix.EINVAL } @@ -913,7 +913,7 @@ func (t *Tsetattr) handle(cs *connState) message { // This might be technically incorrect, as it's possible that // there were multiple links and you can still change the // corresponding inode information. - if ref.isDeleted() { + if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() { return unix.EINVAL } @@ -946,7 +946,7 @@ func (t *Tallocate) handle(cs *connState) message { } // We don't allow allocate on files that have been deleted. - if ref.isDeleted() { + if !cs.server.options.AllocateOnDeleted && ref.isDeleted() { return unix.EINVAL } diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 9c1ada0cb..f3eb8468b 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -12,7 +12,7 @@ MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" # mockgen_reflect is a source file that contains mock generation code that # imports the p9 package and generates a specification via reflection. The # usual generation path must be split into two distinct parts because the full -# source tree is not available to all build targets. Only declared depencies +# source tree is not available to all build targets. Only declared dependencies # are available (and even then, not the Go source files). genrule( name = "mockgen_reflect", diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index fd5ac3dbe..56939d100 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -307,6 +307,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { } // Start the server, synchronized on exit. + h.Attacher.EXPECT().ServerOptions().Return(p9.AttacherOptions{}).Times(1) server := p9.NewServer(h.Attacher) h.wg.Add(1) go func() { diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 6428ad745..e7d129f9d 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -34,6 +34,8 @@ type Server struct { // attacher provides the attach function. attacher Attacher + options AttacherOptions + // pathTree is the full set of paths opened on this server. // // These may be across different connections, but rename operations @@ -48,10 +50,15 @@ type Server struct { renameMu sync.RWMutex } -// NewServer returns a new server. +// NewServer returns a new server. attacher may be nil. func NewServer(attacher Attacher) *Server { + opts := AttacherOptions{} + if attacher != nil { + opts = attacher.ServerOptions() + } return &Server{ attacher: attacher, + options: opts, pathTree: newPathNode(), } } diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index c08d47787..2039a96ad 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -24,6 +24,10 @@ type contextID int const ( // CtxCredentials is a Context.Value key for Credentials. CtxCredentials contextID = iota + + // CtxThreadGroupID is the current thread group ID when a context represents + // a task context. The value is represented as an int32. + CtxThreadGroupID contextID = iota ) // CredentialsFromContext returns a copy of the Credentials used by ctx, or a @@ -35,6 +39,15 @@ func CredentialsFromContext(ctx context.Context) *Credentials { return NewAnonymousCredentials() } +// ThreadGroupIDFromContext returns the current thread group ID when ctx +// represents a task context. +func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) { + if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { + return tgid.(int32), true + } + return 0, false +} + // ContextWithCredentials returns a copy of ctx carrying creds. func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context { return &authContext{ctx, creds} diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go index 07482decf..7515a2772 100644 --- a/pkg/sentry/kernel/mq/mq.go +++ b/pkg/sentry/kernel/mq/mq.go @@ -399,7 +399,7 @@ func (q *Queue) Flush(ctx context.Context) { q.mu.Lock() defer q.mu.Unlock() - pid, ok := context.ThreadGroupIDFromContext(ctx) + pid, ok := auth.ThreadGroupIDFromContext(ctx) if ok { if q.subscriber != nil && pid == q.subscriber.pid { q.subscriber = nil diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index ab938fa3c..bb9a129ab 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -444,7 +444,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch. s.mu.Lock() defer s.mu.Unlock() s.attachTime = ktime.NowFromContext(ctx) - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { // AddMapping is called during a syscall, so ctx should always be a task @@ -468,7 +468,7 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostar // If called from a non-task context we also won't have a threadgroup // id. Silently skip updating the lastAttachDetachPid in that case. - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked()) diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index cb9bcd7c0..ce38d9342 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -86,7 +86,7 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t case auth.CtxCredentials: return t.creds.Load() - case context.CtxThreadGroupID: + case auth.CtxThreadGroupID: return int32(t.tg.ID()) case fs.CtxRoot: if !isTaskGoroutine { diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index 4d73d46ef..fd0ab4c76 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -136,14 +136,14 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink) + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.UMOUNT_NOFOLLOW == 0)) if err != nil { return 0, nil, err } defer tpop.Release(t) opts := vfs.UmountOptions{ - Flags: uint32(flags), + Flags: uint32(flags &^ linux.UMOUNT_NOFOLLOW), } return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ead36880f..5d76adac1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -134,6 +134,7 @@ go_test( srcs = [ "conntrack_test.go", "forwarding_test.go", + "iptables_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c489506bb..1c6060b70 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -119,22 +119,24 @@ type conn struct { // // +checklocks:mu destinationManip bool + + stateMu sync.RWMutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection. // - // +checklocks:mu + // +checklocks:stateMu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. // - // +checklocks:mu + // +checklocks:stateMu lastUsed tcpip.MonotonicTime } // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { - cn.mu.RLock() - defer cn.mu.RUnlock() + cn.stateMu.RLock() + defer cn.stateMu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -147,7 +149,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { // update the connection tracking state. // -// +checklocks:cn.mu +// +checklocks:cn.stateMu func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { if pkt.TransportProtocolNumber != header.TCPProtocolNumber { return @@ -209,17 +211,41 @@ type bucket struct { tuples tupleList } -func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) { - switch pkt.tuple.id().transProto { +// A netAndTransHeadersFunc returns the network and transport headers found +// in an ICMP payload. The transport layer's payload will not be returned. +// +// May panic if the packet does not hold the transport header. +type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) + +func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv4(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the buffer is smaller than + // the total length specified in the IPv4 header. + transHdr := icmpPayload[netHdr.HeaderLength():] + return netHdr, transHdr[:minTransHdrLen] +} + +func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv6(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the IP payload is smaller than + // the payload length specified in the IPv6 header. + transHdr := icmpPayload[header.IPv6MinimumSize:] + return netHdr, transHdr[:minTransHdrLen] +} + +func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { + switch transProto { case header.TCPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.TCP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) + return netHeader, header.TCP(transHeaderBytes), true } case header.UDPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.UDP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) + return netHeader, header.UDP(transHeaderBytes), true } } return nil, nil, false @@ -246,7 +272,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic("should have dropped packets with IPv4 options") } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok { return netHdr, transHdr, true, true } case header.ICMPv6ProtocolNumber: @@ -264,7 +290,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { return netHdr, transHdr, true, true } } @@ -283,34 +309,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro } } -func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { - switch transProto { - case header.TCPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.TCP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } - case header.UDPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.UDP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } +func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + }, true } return tupleID{}, false @@ -349,7 +357,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { return tupleID{}, false, false } - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { return tid, true, true } case header.ICMPv6ProtocolNumber: @@ -370,7 +378,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { } // TODO(https://gvisor.dev/issue/6789): Handle extension headers. - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { return tid, true, true } } @@ -601,14 +609,17 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { // packets are fragmented. reply := pkt.tuple.reply - tid, performManip := func() (tupleID, bool) { - cn.mu.Lock() - defer cn.mu.Unlock() - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = cn.ct.clock.NowMonotonic() - // Update connection state. - cn.updateLocked(pkt, reply) + cn.stateMu.Lock() + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + cn.stateMu.Unlock() + + tid, performManip := func() (tupleID, bool) { + cn.mu.RLock() + defer cn.mu.RUnlock() var tuple *tuple if reply { @@ -730,9 +741,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused deletes timed out entries from the conntrack map. The rules for // reaping are: -// - Most reaping occurs in connFor, which is called on each packet. connFor -// cleans up the bucket the packet's connection maps to. Thus calls to -// reapUnused should be fast. // - Each call to reapUnused traverses a fraction of the conntrack table. // Specifically, it traverses len(ct.buckets)/fractionPerReaping. // - After reaping, reapUnused decides when it should next run based on the @@ -799,45 +807,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // Precondition: ct.mu is read locked and bkt.mu is write locked. // +checklocksread:ct.mu // +checklocks:bkt.mu -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { - if !tuple.conn.timedOut(now) { +func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { + if !reapingTuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap both tuples if the reply appears - // later in the table. - replyBktID := ct.bucket(tuple.id().reply()) - tuple.conn.mu.RLock() - replyTupleInserted := tuple.conn.finalized - tuple.conn.mu.RUnlock() - if bktID > replyBktID && replyTupleInserted { - return true + var otherTuple *tuple + if reapingTuple.reply { + otherTuple = &reapingTuple.conn.original + } else { + otherTuple = &reapingTuple.conn.reply } - // Reap the reply. - if replyTupleInserted { - // Don't re-lock if both tuples are in the same bucket. - if bktID != replyBktID { - replyBkt := &ct.buckets[replyBktID] - replyBkt.mu.Lock() - removeConnFromBucket(replyBkt, tuple) - replyBkt.mu.Unlock() - } else { - removeConnFromBucket(bkt, tuple) - } + otherTupleBktID := ct.bucket(otherTuple.id()) + reapingTuple.conn.mu.RLock() + replyTupleInserted := reapingTuple.conn.finalized + reapingTuple.conn.mu.RUnlock() + + // To maintain lock order, we can only reap both tuples if the tuple for the + // other direction appears later in the table. + if bktID > otherTupleBktID && replyTupleInserted { + return true } - bkt.tuples.Remove(tuple) - return true -} + bkt.tuples.Remove(reapingTuple) + + if !replyTupleInserted { + // The other tuple is the reply which has not yet been inserted. + return true + } -// +checklocks:b.mu -func removeConnFromBucket(b *bucket, tuple *tuple) { - if tuple.reply { - b.tuples.Remove(&tuple.conn.original) + // Reap the other connection. + if bktID == otherTupleBktID { + // Don't re-lock if both tuples are in the same bucket. + bkt.tuples.Remove(otherTuple) } else { - b.tuples.Remove(&tuple.conn.reply) + otherTupleBkt := &ct.buckets[otherTupleBktID] + otherTupleBkt.mu.Lock() + otherTupleBkt.tuples.Remove(otherTuple) + otherTupleBkt.mu.Unlock() } + + return true } func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go new file mode 100644 index 000000000..1788e98c9 --- /dev/null +++ b/pkg/tcpip/stack/iptables_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestNATedConnectionReap tests that NATed connections are properly reaped. +func TestNATedConnectionReap(t *testing.T) { + // Note that the network protocol used for this test doesn't matter as this + // test focuses on reaping, not anything related to a specific network + // protocol. + + const ( + nattedDstPort = 1 + srcPort = 2 + dstPort = 3 + + nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ) + + clock := faketime.NewManualClock() + iptables := DefaultTables(0 /* seed */, clock) + + table := Table{ + Rules: []Rule{ + // Prerouting + { + Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort}, + }, + { + Target: &AcceptTarget{}, + }, + + // Input + { + Target: &AcceptTarget{}, + }, + + // Forward + { + Target: &AcceptTarget{}, + }, + + // Output + { + Target: &AcceptTarget{}, + }, + + // Postrouting + { + Target: &AcceptTarget{}, + }, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 2, + Forward: 3, + Output: 4, + Postrouting: 5, + }, + } + if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err) + } + + // Stop the reaper if it is running so we can reap manually as it is started + // on the first change to IPTables. + iptables.reaperDone <- struct{}{} + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize, + }) + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udp.SetSourcePort(srcPort) + udp.SetDestinationPort(dstPort) + udp.SetChecksum(0) + udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + srcAddr, + dstAddr, + uint16(len(udp)), + ))) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(udp)), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + originalTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get original tuple ID") + } + + if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) { + t.Fatal("got ipt.CheckPrerouting(...) = false, want = true") + } + if !iptables.CheckInput(pkt, "" /* inNicName */) { + t.Fatal("got ipt.CheckInput(...) = false, want = true") + } + + invertedReplyTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get NATed packet's tuple ID") + } + if invertedReplyTID == originalTID { + t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID) + } + replyTID := invertedReplyTID.reply() + + originalBktID := iptables.connections.bucket(originalTID) + replyBktID := iptables.connections.bucket(replyTID) + + // This test depends on the original and reply tuples mapping to different + // buckets. + if originalBktID == replyBktID { + t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID) + } + + lowerBktID := originalBktID + if lowerBktID > replyBktID { + lowerBktID = replyBktID + } + + runReaper := func() { + // Reaping the bucket with the lower ID should reap both tuples of the + // connection if it has timed out. + // + // We will manually pick the next start bucket ID and don't use the + // interval so we ignore the return values. + _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */) + } + + iptables.connections.mu.RLock() + buckets := iptables.connections.buckets + iptables.connections.mu.RUnlock() + + originalBkt := &buckets[originalBktID] + replyBkt := &buckets[replyBktID] + + // Run the reaper and make sure the tuples were not reaped. + reapAndCheckForConnections := func() { + t.Helper() + + runReaper() + + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil { + t.Error("expected to get original tuple") + } + + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil { + t.Error("expected to get reply tuple") + } + + if t.Failed() { + t.FailNow() + } + } + + // Connection was just added and no time has passed - it should not be reaped. + reapAndCheckForConnections() + + // Time must advance past the unestablished timeout for a connection to be + // reaped. + clock.Advance(unestablishedTimeout) + reapAndCheckForConnections() + + // Connection should now be reaped. + clock.Advance(1) + runReaper() + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil { + t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple) + } + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil { + t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple) + } + // Make sure we don't have stale tuples just lying around. + // + // We manually check the buckets as connForTID will skip over tuples that + // have timed out. + checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) { + t.Helper() + + bkt.mu.RLock() + defer bkt.mu.RUnlock() + for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() { + if tuple.id() == originalTID { + t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply) + } + } + } + checkNoTupleInBucket(originalBkt, originalTID, false /* reply */) + checkNoTupleInBucket(replyBkt, replyTID, true /* reply */) +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 7fe3b29d9..b2383576c 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) { } func TestNATICMPError(t *testing.T) { - const srcPort = 1234 - const dstPort = 5432 + const ( + srcPort = 1234 + dstPort = 5432 + dataSize = 4 + ) type icmpTypeTest struct { name string @@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) { netProto: ipv4.ProtocolNumber, host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original) - hdr := buffer.NewPrependable(totalLen) + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) if n := copy(hdr.Prepend(len(original)), original); n != len(original) { t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) } @@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) { icmp.SetType(header.ICMPv4Type(icmpType)) icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.ICMPv4ProtocolNumber, utils.Host1IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, @@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.UDPMinimumSize - hdr := buffer.NewPrependable(totalLen) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.UDPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.TCPMinimumSize - hdr := buffer.NewPrependable(totalLen) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.TCPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) { Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, })) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), payloadLen, header.ICMPv6ProtocolNumber, utils.Host1IPv6Addr.AddressWithPrefix.Address, @@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.UDPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(udp), header.UDPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.TCPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(tcp), header.TCPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) { }, } + trimTests := []struct { + name string + trimLen int + expectNATedICMP bool + }{ + { + name: "Trim nothing", + trimLen: 0, + expectNATedICMP: true, + }, + { + name: "Trim data", + trimLen: dataSize, + expectNATedICMP: true, + }, + { + name: "Trim data and transport header", + trimLen: dataSize + 1, + expectNATedICMP: false, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { for _, transportType := range test.transportTypes { t.Run(transportType.name, func(t *testing.T) { for _, icmpType := range test.icmpTypes { t.Run(icmpType.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - - ep1 := channel.New(1, header.IPv6MinimumMTU, "") - ep2 := channel.New(1, header.IPv6MinimumMTU, "") - utils.SetupRouterStack(t, s, ep1, ep2) - - ipv6 := test.netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, + for _, trimTest := range trimTests { + t.Run(trimTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + + ep1 := channel.New(1, header.IPv6MinimumMTU, "") + ep2 := channel.New(1, header.IPv6MinimumMTU, "") + utils.SetupRouterStack(t, s, ep1, ep2) + + ipv6 := test.netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, + }, + { + Target: &stack.AcceptTarget{}, + }, }, - Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, }, - Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } + } - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } - ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(), - })) + buf := transportType.buf - { - pkt, ok := ep1.Read() - if !ok { - t.Fatal("expected to read a packet on ep1") - } - pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) - transportType.checkNATed(t, pktView) - if t.Failed() { - t.FailNow() - } + ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: append(buffer.View(nil), buf...).ToVectorisedView(), + })) - ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), - })) - } + { + pkt, ok := ep1.Read() + if !ok { + t.Fatal("expected to read a packet on ep1") + } + pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) + transportType.checkNATed(t, pktView) + if t.Failed() { + t.FailNow() + } + + pktView = pktView[:len(pktView)-trimTest.trimLen] + buf = buf[:len(buf)-trimTest.trimLen] + + ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), + })) + } - pkt, ok := ep2.Read() - if ok != icmpType.expectResponse { - t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse) - } - if !icmpType.expectResponse { - return + pkt, ok := ep2.Read() + expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP + if ok != expectResponse { + t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse) + } + if !expectResponse { + return + } + test.decrementTTL(buf) + test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val) + }) } - test.decrementTTL(transportType.buf) - test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val) }) } }) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 9fb3ebd95..f819cf8fb 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -78,6 +78,11 @@ type DefaultRoute struct { Name string } +type Neighbor struct { + IP net.IP + HardwareAddr net.HardwareAddr +} + // FDBasedLink configures an fd-based link. type FDBasedLink struct { Name string @@ -90,6 +95,7 @@ type FDBasedLink struct { RXChecksumOffload bool LinkAddress net.HardwareAddr QDisc config.QueueingDiscipline + Neighbors []Neighbor // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -241,6 +247,11 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } routes = append(routes, route) } + + for _, neigh := range link.Neighbors { + proto, tcpipAddr := ipToAddressAndProto(neigh.IP) + n.Stack.AddStaticNeighbor(nicID, proto, tcpipAddr, tcpip.LinkAddress(neigh.HardwareAddr)) + } } if !args.Defaultv4Gateway.Route.Empty() { diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 600b21189..3d610199c 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -140,6 +140,17 @@ func (a *attachPoint) Attach() (p9.File, error) { return lf, nil } +// ServerOptions implements p9.Attacher. It's safe to call SetAttr and Allocate +// on deleted files because fsgofer either uses an existing FD or opens a new +// one using the magic symlink in `/proc/[pid]/fd` and cannot mistakely open +// a file that was created in the same path as the delete file. +func (a *attachPoint) ServerOptions() p9.AttacherOptions { + return p9.AttacherOptions{ + SetAttrOnDeleted: true, + AllocateOnDeleted: true, + } +} + // makeQID returns a unique QID for the given stat buffer. func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID { a.deviceMu.Lock() diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 3451d1037..03c5de2c6 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -173,6 +173,23 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG continue } + // Collect data from the ARP table. + dump, err := netlink.NeighList(iface.Index, 0) + if err != nil { + return fmt.Errorf("fetching ARP table for %q: %w", iface.Name, err) + } + + var neighbors []boot.Neighbor + for _, n := range dump { + // There are only two "good" states NUD_PERMANENT and NUD_REACHABLE, + // but NUD_REACHABLE is fully dynamic and will be re-probed anyway. + if n.State == netlink.NUD_PERMANENT { + log.Debugf("Copying a static ARP entry: %+v %+v", n.IP, n.HardwareAddr) + // No flags are copied because Stack.AddStaticNeighbor does not support flags right now. + neighbors = append(neighbors, boot.Neighbor{IP: n.IP, HardwareAddr: n.HardwareAddr}) + } + } + // Scrape the routes before removing the address, since that // will remove the routes as well. routes, defv4, defv6, err := routesForIface(iface) @@ -203,6 +220,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG RXChecksumOffload: rxChecksumOffload, NumChannels: numNetworkChannels, QDisc: qDisc, + Neighbors: neighbors, } // Get the link for the interface. diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index f748d685a..7952fd969 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1053,3 +1053,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:verity_mount_test", ) + +syscall_test( + test = "//test/syscalls/linux:deleted_test", +) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 6217ff4dc..020c4673a 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -4432,3 +4432,18 @@ cc_binary( "@com_google_absl//absl/container:flat_hash_set", ], ) + +cc_binary( + name = "deleted_test", + testonly = 1, + srcs = ["deleted.cc"], + linkstatic = 1, + deps = [ + "//test/util:file_descriptor", + "//test/util:fs_util", + gtest, + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/syscalls/linux/deleted.cc b/test/syscalls/linux/deleted.cc new file mode 100644 index 000000000..695ceafd3 --- /dev/null +++ b/test/syscalls/linux/deleted.cc @@ -0,0 +1,116 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <time.h> +#include <unistd.h> + +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +constexpr mode_t mode = 1; + +namespace gvisor { +namespace testing { + +namespace { + +PosixErrorOr<FileDescriptor> createdDeleted() { + auto path = NewTempAbsPath(); + PosixErrorOr<FileDescriptor> fd = Open(path, O_RDWR | O_CREAT, mode); + if (!fd.ok()) { + return fd.error(); + } + + auto err = Unlink(path); + if (!err.ok()) { + return err; + } + return fd; +} + +TEST(DeletedTest, Utime) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + const struct timespec times[2] = {{10, 0}, {20, 0}}; + EXPECT_THAT(futimens(fd.get(), times), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(10, stat.st_atime); + EXPECT_EQ(20, stat.st_mtime); +} + +TEST(DeletedTest, Chmod) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + ASSERT_THAT(fchmod(fd.get(), mode + 1), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(mode + 1, stat.st_mode & ~S_IFMT); +} + +TEST(DeletedTest, Truncate) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + const std::string data = "foobar"; + ASSERT_THAT(write(fd.get(), data.c_str(), data.size()), SyscallSucceeds()); + + ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + ASSERT_EQ(stat.st_size, 0); +} + +TEST(DeletedTest, Fallocate) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + ASSERT_THAT(fallocate(fd.get(), 0, 0, 123), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(123, stat.st_size); +} + +// Tests that a file can be created with the same path as a deleted file that +// still have an open FD to it. +TEST(DeletedTest, Replace) { + auto path = NewTempAbsPath(); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, mode)); + ASSERT_NO_ERRNO(Unlink(path)); + + auto other = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT | O_EXCL, mode)); + + auto stat = ASSERT_NO_ERRNO_AND_VALUE(Fstat(fd.get())); + auto stat_other = ASSERT_NO_ERRNO_AND_VALUE(Fstat(other.get())); + ASSERT_NE(stat.st_ino, stat_other.st_ino); + + // Check that the path points to the new file. + stat = ASSERT_NO_ERRNO_AND_VALUE(Stat(path)); + ASSERT_EQ(stat.st_ino, stat_other.st_ino); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 3c7311782..e2a41d172 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -115,6 +115,40 @@ TEST(MountTest, OpenFileBusy) { EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, UmountNoFollow) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto const mountPoint = NewTempAbsPathInDir(dir.path()); + ASSERT_THAT(mkdir(mountPoint.c_str(), 0777), SyscallSucceeds()); + + // Create a symlink in dir which will point to the actual mountpoint. + const std::string symlinkInDir = NewTempAbsPathInDir(dir.path()); + EXPECT_THAT(symlink(mountPoint.c_str(), symlinkInDir.c_str()), + SyscallSucceeds()); + + // Create a symlink to the dir. + const std::string symlinkToDir = NewTempAbsPath(); + EXPECT_THAT(symlink(dir.path().c_str(), symlinkToDir.c_str()), + SyscallSucceeds()); + + // Should fail with ELOOP when UMOUNT_NOFOLLOW is specified and the last + // component is a symlink. + auto mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("", mountPoint, "tmpfs", 0, "mode=0700", 0)); + EXPECT_THAT(umount2(symlinkInDir.c_str(), UMOUNT_NOFOLLOW), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(unlink(symlinkInDir.c_str()), SyscallSucceeds()); + + // UMOUNT_NOFOLLOW should only apply to the last path component. A symlink in + // non-last path component should be just fine. + EXPECT_THAT(umount2(JoinPath(symlinkToDir, Basename(mountPoint)).c_str(), + UMOUNT_NOFOLLOW), + SyscallSucceeds()); + mount.Release(); +} + TEST(MountTest, UmountDetach) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 253411858..1c24d9ffc 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -188,6 +188,14 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, return NoError(); } +PosixError Unlink(absl::string_view path) { + int res = unlink(std::string(path).c_str()); + if (res < 0) { + return PosixError(errno, absl::StrCat("unlink ", path)); + } + return NoError(); +} + PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, int flags) { int res = unlinkat(dfd.get(), std::string(path).c_str(), flags); diff --git a/test/util/fs_util.h b/test/util/fs_util.h index bb2d1d3c8..3ae0a725a 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -71,6 +71,7 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, dev_t dev); // Unlink the file. +PosixError Unlink(absl::string_view path); PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, int flags); |