From a2c51efe3669f0380042b2375eae79e403d3680c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 29 Oct 2019 16:13:43 -0700 Subject: Add endpoint tracking to the stack. In the future this will replace DanglingEndpoints. DanglingEndpoints must be kept for now due to issues with save/restore. This is arguably a cleaner design and allows the stack to know which transport endpoints might still be using its link endpoints. Updates #837 PiperOrigin-RevId: 277386633 --- pkg/sentry/socket/netstack/stack.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d5db8c17c..a0db2d4fd 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -291,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} -- cgit v1.2.3 From 66ebb6575f929a389d3c929977ed5e31d706fcfe Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 7 Nov 2019 09:45:26 -0800 Subject: Add support for TIME_WAIT timeout. This change adds explicit support for honoring the 2MSL timeout for sockets in TIME_WAIT state. It also adds support for the TCP_LINGER2 option that allows modification of the FIN_WAIT2 state timeout duration for a given socket. It also adds an option to modify the Stack wide TIME_WAIT timeout but this is only for testing. On Linux this is fixed at 60s. Further, we also now correctly process RST's in CLOSE_WAIT and close the socket similar to linux without moving it to error state. We also now handle SYN in ESTABLISHED state as per RFC5961#section-4.1. Earlier we would just drop these SYNs. Which can result in some tests that pass on linux to fail on gVisor. Netstack now honors TIME_WAIT correctly as well as handles the following cases correctly. - TCP RSTs in TIME_WAIT are ignored. - A duplicate TCP FIN during TIME_WAIT extends the TIME_WAIT and a dup ACK is sent in response to the FIN as the dup FIN indicates potential loss of the original final ACK. - An out of order segment during TIME_WAIT generates a dup ACK. - A new SYN w/ a sequence number > the highest sequence number in the previous connection closes the TIME_WAIT early and opens a new connection. Further to make the SYN case work correctly the ISN (Initial Sequence Number) generation for Netstack has been updated to be as per RFC. Its not a pure random number anymore and follows the recommendation in https://tools.ietf.org/html/rfc6528#page-3. The current hash used is not a cryptographically secure hash function. A separate change will update the hash function used to Siphash similar to what is used in Linux. PiperOrigin-RevId: 279106406 --- pkg/sentry/socket/netstack/netstack.go | 20 + pkg/tcpip/adapters/gonet/gonet_test.go | 12 +- pkg/tcpip/stack/stack.go | 20 +- pkg/tcpip/stack/transport_demuxer.go | 33 +- pkg/tcpip/tcpip.go | 12 +- pkg/tcpip/transport/tcp/BUILD | 2 +- pkg/tcpip/transport/tcp/accept.go | 17 +- pkg/tcpip/transport/tcp/connect.go | 322 ++++++++++++-- pkg/tcpip/transport/tcp/endpoint.go | 101 ++++- pkg/tcpip/transport/tcp/endpoint_state.go | 26 +- pkg/tcpip/transport/tcp/protocol.go | 43 ++ pkg/tcpip/transport/tcp/rcv.go | 167 ++++++- pkg/tcpip/transport/tcp/tcp_test.go | 622 ++++++++++++++++++++++++++- test/syscalls/BUILD | 22 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_inet_loopback.cc | 336 +++++++++++++++ test/syscalls/linux/socket_ip_tcp_generic.cc | 93 ++++ 17 files changed, 1736 insertions(+), 113 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 27c6692c4..d92399efd 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1173,6 +1173,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1556,6 +1568,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 99809df75..2f8d8e822 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -402,11 +402,11 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations @@ -544,7 +544,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, icmpRateLimiter: NewICMPRateLimiter(), - portSeed: generateRandUint32(), + seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, @@ -1186,6 +1186,12 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { s.mu.Unlock() } +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1573,12 +1579,12 @@ func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRoute return nil } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only. -func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed } func generateRandUint32() uint32 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 594570216..cb805522b 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -103,7 +103,6 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. @@ -507,10 +506,40 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr if ep, ok := eps.endpoints[nid]; ok { matchedEPs = append(matchedEPs, ep) } - return matchedEPs } +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } + } + + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3edb513d4..bd5eb89ca 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -586,6 +586,16 @@ type MaxSegOption int // A zero value indicates the default. type TTLOption uint8 +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -1329,8 +1339,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f1dbc6f91..3f47b328d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -71,7 +71,7 @@ filegroup( go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index cb0e13ebc..0e8e0a2b4 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -269,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -289,7 +289,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -361,6 +361,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -368,6 +369,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -543,6 +549,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ca982c451..a114c06c1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -139,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -809,7 +836,19 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.state = StateError e.HardError = err if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1< + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -856,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -955,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1001,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // RTT itself. e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.mu.Lock() + e.state = StateEstablished + e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -1008,10 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - if e.state != StateEstablished { - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished - } drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1042,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.state = StateClose + e.mu.Unlock() + return nil }, }, { @@ -1085,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1117,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1143,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1167,6 +1292,23 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. @@ -1176,8 +1318,130 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 79fec6b77..04c92c04c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -320,6 +325,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -503,6 +513,16 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -599,6 +619,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.SetSockOptInt(tcpip.DelayOption, 1) } + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -686,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -706,6 +738,8 @@ func (e *endpoint) Close() { e.isPortReserved = false } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -731,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -1349,6 +1381,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1562,6 +1616,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1696,7 +1756,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1782,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1799,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1817,14 +1877,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1832,11 +1889,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1928,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -2058,6 +2119,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 19f003b6b..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,7 +195,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } @@ -219,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -265,8 +280,11 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c8e4a0d7e..89b965c23 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,6 +55,14 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP @@ -93,6 +102,8 @@ type protocol struct { congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -212,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -262,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -274,5 +315,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..068b90fb6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: @@ -253,23 +259,105 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } // Defer segment processing if it can't be consumed now. @@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index f4ea5f091..0c1704d74 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -206,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -256,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -264,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -285,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -376,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -396,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -413,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), ), @@ -1110,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -3085,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -3092,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3135,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -3155,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -3166,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -3176,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4347,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4893,3 +4931,545 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) } } + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 3e5b6b3c3..722d14b53 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -9,7 +9,7 @@ syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -434,7 +434,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -445,7 +445,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -458,19 +458,19 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -481,13 +481,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -498,7 +498,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -560,7 +560,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -599,7 +599,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 93bff8299..f8b8cb724 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2141,6 +2141,7 @@ cc_library( deps = [ ":socket_test_util", "//test/util:test_util", + "//test/util:thread_util", "@com_google_googletest//:gtest", ], alwayslink = 1, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index ab375aaaf..2eeee352e 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -31,6 +32,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" @@ -267,6 +269,340 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple + // reservations which causes the bind() to succeed on gVisor but connect + // correctly fails. + if (IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } else { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + accepted.reset(); + absl::SleepFor(absl::Seconds(1)); + conn_fd.reset(); + absl::SleepFor(absl::Seconds(1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 592448289..a37b49447 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -26,6 +26,7 @@ #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -243,6 +244,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -696,5 +722,72 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From b1d44be7ad893bd6bdfd164a54a7142f4462414b Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Fri, 6 Dec 2019 17:15:52 -0800 Subject: Add TCP stats for connection close and keep-alive timeouts. Fix bugs in updates to TCP CurrentEstablished stat. Fixes #1277 PiperOrigin-RevId: 284292459 --- pkg/sentry/socket/netstack/netstack.go | 2 ++ pkg/tcpip/tcpip.go | 8 ++++++ pkg/tcpip/transport/tcp/connect.go | 5 ++-- pkg/tcpip/transport/tcp/snd.go | 1 - pkg/tcpip/transport/tcp/tcp_test.go | 46 ++++++++++++++++++++++++++++++++++ 5 files changed, 58 insertions(+), 4 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d92399efd..fe5a46aa3 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -151,6 +151,8 @@ var Metrics = tcpip.Stats{ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5746043cc..d5bb5b6ed 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -924,6 +924,14 @@ type TCPStats struct { // ESTABLISHED state or the CLOSE-WAIT state. EstablishedResets *StatCounter + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2975a1c3c..3d059c302 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -924,6 +924,7 @@ func (e *endpoint) transitionToStateCloseLocked() { } e.cleanupLocked() e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() } // tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed @@ -1094,6 +1095,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -1179,8 +1181,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1389,7 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() e.transitionToStateCloseLocked() } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d3f7c9125..8332a0179 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -674,7 +674,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } - s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 52c2fa7e3..bc5cfcf0e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -548,6 +562,14 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + // Check if the endpoint was moved to CLOSED and netstack a reset in // response to the ACK packet that we sent after last-ACK. checker.IPv4(t, c.GetPacket(), @@ -2694,6 +2716,13 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -4363,9 +4392,17 @@ func TestKeepalive(t *testing.T) { ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -5992,6 +6029,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) } + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -6120,6 +6159,13 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SeqNum(uint32(ackHeaders.AckNum)), checker.AckNum(uint32(ackHeaders.SeqNum)), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestTCPCloseWithData(t *testing.T) { -- cgit v1.2.3 From 6fc9f0aefd89ce42ef2c38ea7853f9ba7c4bee04 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 11 Dec 2019 17:51:37 -0800 Subject: Add support for TCP_USER_TIMEOUT option. The implementation follows the linux behavior where specifying a TCP_USER_TIMEOUT will cause the resend timer to honor the user specified timeout rather than the default rto based timeout. Further it alters when connections are timedout due to keepalive failures. It does not alter the behavior of when keepalives are sent. This is as per the linux behavior. PiperOrigin-RevId: 285099795 --- pkg/sentry/socket/netstack/netstack.go | 23 ++++ pkg/tcpip/tcpip.go | 5 + pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 15 +++ pkg/tcpip/transport/tcp/connect.go | 19 ++- pkg/tcpip/transport/tcp/endpoint.go | 19 +++ pkg/tcpip/transport/tcp/protocol.go | 21 ++- pkg/tcpip/transport/tcp/rcv.go | 19 ++- pkg/tcpip/transport/tcp/rcv_state.go | 29 ++++ pkg/tcpip/transport/tcp/snd.go | 48 ++++++- pkg/tcpip/transport/tcp/snd_state.go | 10 ++ pkg/tcpip/transport/tcp/tcp_test.go | 194 ++++++++++++++++++++++++--- test/syscalls/linux/socket_inet_loopback.cc | 56 +++++++- test/syscalls/linux/socket_ip_tcp_generic.cc | 63 +++++++++ test/syscalls/linux/tcp_socket.cc | 25 ++++ 15 files changed, 509 insertions(+), 38 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/rcv_state.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index fe5a46aa3..8a6522eac 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1127,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1563,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d5bb5b6ed..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -576,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 455a1c098..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 74df3edfb..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,6 +242,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() @@ -350,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3d059c302..4c34fc9d2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -862,7 +862,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1087,12 +1087,24 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() @@ -1112,7 +1124,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -1120,6 +1131,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -1127,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -1239,6 +1252,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1405,6 +1419,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if s == nil { break } + e.tryDeliverSegmentFromClosedEndpoint(s) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4861ab513..dd8b47cbe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -341,6 +341,7 @@ type endpoint struct { // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID @@ -474,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1333,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1591,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 89b965c23..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -162,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5ee499c36..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -50,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -360,6 +364,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc5cfcf0e..2a83f7bcc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -323,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -460,18 +460,17 @@ func TestConnectResetAfterClose(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } -// TestClosingWithEnqueuedSegments tests handling of -// still enqueued segments when the endpoint transitions -// to StateClose. The in-flight segments would be re-enqueued -// to a any listening endpoint. +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -576,8 +575,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) } @@ -914,7 +913,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -942,7 +941,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -4291,8 +4290,9 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -4382,13 +4382,29 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) @@ -6157,8 +6173,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) @@ -6336,7 +6352,147 @@ func TestTCPCloseWithData(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index fa4358ae4..761c3a9fe 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -206,7 +206,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked // before function end. - // ds.reset() + // ds.reset(); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -603,6 +603,60 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index c74273436..57ce8e169 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -812,5 +812,68 @@ TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); } +TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kNeg = -10; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAbove = 10; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kAbove, sizeof(kAbove)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kAbove); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 99863b0ed..c503f3568 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1175,6 +1175,31 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { } } +TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int kTCPUserTimeout = -1; + EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallFailsWithErrno(EINVAL)); + } + + // kTCPUserTimeout is in milliseconds. + constexpr int kTCPUserTimeout = 100; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPUserTimeout); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 378d6c1f3697b8b939e6632e980562bfc8fb2781 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Thu, 12 Dec 2019 11:07:25 -0800 Subject: unix: allow to bind unix sockets only to AF_UNIX addresses Reported-by: syzbot+2c0bcfd87fb4e8b7b009@syzkaller.appspotmail.com PiperOrigin-RevId: 285228312 --- pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/socket/unix/unix.go | 3 +++ test/syscalls/linux/socket_unix.cc | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8a6522eac..140851c17 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -326,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..885758054 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -118,6 +118,9 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { func extractPath(sockaddr []byte) (string, *syserr.Error) { addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 8a28202a8..4cf1f76f1 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -65,6 +65,21 @@ TEST_P(UnixSocketPairTest, BindToBadName) { SyscallFailsWithErrno(ENOENT)); } +TEST_P(UnixSocketPairTest, BindToBadFamily) { + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + constexpr char kBadName[] = "/some/path/that/does/not/exist"; + sockaddr_un sockaddr; + sockaddr.sun_family = AF_INET; + memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName)); + + EXPECT_THAT( + bind(pair->first_fd(), reinterpret_cast(&sockaddr), + sizeof(sockaddr)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[10]; -- cgit v1.2.3 From e013c48c78c9a7daf245b7de9563e3a0bd8a1e97 Mon Sep 17 00:00:00 2001 From: Ryan Heacock Date: Tue, 24 Dec 2019 08:48:14 -0800 Subject: Enable IP_RECVTOS socket option for datagram sockets Added the ability to get/set the IP_RECVTOS socket option on UDP endpoints. If enabled, TOS from the incoming Network Header passed as ancillary data in the ControlMessages. Test: * Added unit test to udp_test.go that tests getting/setting as well as verifying that we receive expected TOS from incoming packet. * Added a syscall test PiperOrigin-RevId: 287029703 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 42 ++++++++++++++++- pkg/tcpip/checker/checker.go | 16 +++++++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 6 ++- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 31 ++++++++++++- pkg/tcpip/transport/udp/udp_test.go | 69 ++++++++++++++++++++++++---- test/syscalls/linux/socket_ip_udp_generic.cc | 40 ++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 11 files changed, 201 insertions(+), 19 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index af1a4e95f..b649dd021 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..d2f263402 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1323,6 +1323,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.ReceiveTOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + if v { + return int32(1), nil + } + return int32(0), nil + default: emitUnimplementedEventIP(t, name) } @@ -1808,6 +1823,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOpt( + tcpip.ReceiveTOSOption(v != 0), + )) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1828,7 +1853,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2139,6 +2163,21 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func (s *SocketOperations) fillCmsgTOS(cmsg *socket.ControlMessages) { + if s.skType != linux.SOCK_DGRAM { + return + } + var receiveTOS tcpip.ReceiveTOSOption + if err := s.Endpoint.GetSockOpt(&receiveTOS); err != nil { + return + } + if !receiveTOS { + return + } + cmsg.IP.HasTOS = s.readCM.HasTOS + cmsg.IP.TOS = s.readCM.TOS +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2244,6 +2283,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) + s.fillCmsgTOS(&cmsg) return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..542abc99d 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + } + if got := cm.TOS; got != want { + t.Fatalf("got cm.TOS = %d, want %d", got, want) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ddd014658..a4556674b 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -575,7 +575,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7a9600679..251336224 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -829,7 +829,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..5c7b2af88 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS int8 + TOS uint8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -666,6 +666,10 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 +// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS +// ancillary message is passed with incoming packets. +type ReceiveTOSOption bool + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..6d23ab5a1 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,7 +510,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..269470ed4 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,6 +32,7 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -114,6 +115,10 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -244,7 +249,12 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + return p.data.ToView(), tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + HasTOS: e.receiveTOS, + TOS: p.tos, + }, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -656,6 +666,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sendTOS = uint8(v) e.mu.Unlock() return nil + + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = bool(v) + e.mu.Unlock() + return nil } return nil } @@ -792,6 +808,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.ReceiveTOSOption: + e.mu.RLock() + *o = tcpip.ReceiveTOSOption(e.receiveTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1238,6 +1260,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() + // Save any useful information from the NetworkHeader to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + // This packet has already been validated before being passed up the stack. + packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + } + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7051a7a9c..43b8b35ba 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,6 +56,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -556,8 +558,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -572,12 +574,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -610,15 +612,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1286,7 +1294,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1321,7 +1329,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1348,6 +1356,49 @@ func TestTOSV6(t *testing.T) { } } +func TestReceiveTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Verify that setting and reading the option works. + const recvTos = true + var v tcpip.ReceiveTOSOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, false) + } + + if err := c.ep.SetSockOpt(tcpip.ReceiveTOSOption(recvTos)); err != nil { + c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.ReceiveTOSOption(recvTos), err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.ReceiveTOSOption(recvTos); v != want { + c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, want) + } + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Verify that the correct received TOS is actually handed through as + // ancillary data to the ControlMessages struct. + testRead(c, flow, checker.ReceiveTOS(testTOS)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 66eb68857..53290bed7 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } +// Ensure that Receiving TOS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS works as expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index dc35c2f50..68e0a8109 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && + !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 87e4d03fdf576348ac7023c599e0fc66ad4cccbd Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 26 Dec 2019 13:04:14 -0800 Subject: Automated rollback of changelist 287029703 PiperOrigin-RevId: 287217899 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 42 +---------------- pkg/tcpip/checker/checker.go | 16 ------- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 6 +-- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 31 +------------ pkg/tcpip/transport/udp/udp_test.go | 69 ++++------------------------ test/syscalls/linux/socket_ip_udp_generic.cc | 40 ---------------- test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 11 files changed, 19 insertions(+), 201 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index b649dd021..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d2f263402..140851c17 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1323,21 +1323,6 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil - case linux.IP_RECVTOS: - if outLen < sizeOfInt32 { - return nil, syserr.ErrInvalidArgument - } - - var v tcpip.ReceiveTOSOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - if v { - return int32(1), nil - } - return int32(0), nil - default: emitUnimplementedEventIP(t, name) } @@ -1823,16 +1808,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) - case linux.IP_RECVTOS: - v, err := parseIntOrChar(optVal) - if err != nil { - return err - } - - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.ReceiveTOSOption(v != 0), - )) - case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1853,6 +1828,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, + linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2163,21 +2139,6 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } -func (s *SocketOperations) fillCmsgTOS(cmsg *socket.ControlMessages) { - if s.skType != linux.SOCK_DGRAM { - return - } - var receiveTOS tcpip.ReceiveTOSOption - if err := s.Endpoint.GetSockOpt(&receiveTOS); err != nil { - return - } - if !receiveTOS { - return - } - cmsg.IP.HasTOS = s.readCM.HasTOS - cmsg.IP.TOS = s.readCM.TOS -} - // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2283,7 +2244,6 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) - s.fillCmsgTOS(&cmsg) return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 542abc99d..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,9 +33,6 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -161,19 +158,6 @@ func FragmentFlags(flags uint8) NetworkChecker { } } -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) - } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) - } - } -} - // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a4556674b..ddd014658 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -575,7 +575,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// AddressRanges returns the Subnets associated with this NIC. +// Subnets returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 251336224..7a9600679 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -829,7 +829,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICAddressRanges returns a map of NICIDs to their associated subnets. +// NICSubnets returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5c7b2af88..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS uint8 + TOS int8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -666,10 +666,6 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 -// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS -// ancillary message is passed with incoming packets. -type ReceiveTOSOption bool - // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6d23ab5a1..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,7 +510,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 269470ed4..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,7 +32,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -115,10 +114,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -249,12 +244,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{ - HasTimestamp: true, - Timestamp: p.timestamp, - HasTOS: e.receiveTOS, - TOS: p.tos, - }, nil + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -666,12 +656,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sendTOS = uint8(v) e.mu.Unlock() return nil - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = bool(v) - e.mu.Unlock() - return nil } return nil } @@ -808,12 +792,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil - case *tcpip.ReceiveTOSOption: - e.mu.RLock() - *o = tcpip.ReceiveTOSOption(e.receiveTOS) - e.mu.RUnlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -1260,13 +1238,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() - // Save any useful information from the NetworkHeader to the packet. - switch r.NetProto { - case header.IPv4ProtocolNumber: - // This packet has already been validated before being passed up the stack. - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - } - packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 43b8b35ba..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,7 +56,6 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -454,7 +453,6 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, - TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -558,8 +556,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -574,12 +572,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) + v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, cm, err = c.ep.Read(&addr) + v, _, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -612,21 +610,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, cm) - } - c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { +// its correctness. +func testRead(c *testContext, flow testFlow) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1294,7 +1286,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tos = 0xC0 var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1329,7 +1321,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tos = 0xC0 var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1356,49 +1348,6 @@ func TestTOSV6(t *testing.T) { } } -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Verify that setting and reading the option works. - const recvTos = true - var v tcpip.ReceiveTOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, false) - } - - if err := c.ep.SetSockOpt(tcpip.ReceiveTOSOption(recvTos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.ReceiveTOSOption(recvTos), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.ReceiveTOSOption(recvTos); v != want { - c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, want) - } - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Verify that the correct received TOS is actually handed through as - // ancillary data to the ControlMessages struct. - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) - } -} - func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 53290bed7..66eb68857 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,46 +209,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 68e0a8109..dc35c2f50 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,9 +1349,8 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,8 +1421,7 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 8cc1c35bbdc5c9bd6b3965311497885ce72317a8 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 12 Dec 2019 15:48:24 -0800 Subject: Write simple ACCEPT rules to the filter table. This gets us closer to passing the iptables tests and opens up iptables so it can be worked on by multiple people. A few restrictions are enforced for security (i.e. we don't want to let users write a bunch of iptables rules and then just not enforce them): - Only the filter table is writable. - Only ACCEPT rules with no matching criteria can be added. --- go.mod | 31 +-- go.sum | 10 + pkg/abi/linux/netfilter.go | 82 +++--- pkg/sentry/socket/netfilter/BUILD | 1 + pkg/sentry/socket/netfilter/netfilter.go | 411 ++++++++++++++++++++++++------- pkg/sentry/socket/netstack/netstack.go | 23 +- pkg/tcpip/iptables/iptables.go | 114 ++++++--- pkg/tcpip/iptables/targets.go | 8 + pkg/tcpip/iptables/types.go | 55 ++--- test/iptables/filter_input.go | 1 + test/iptables/iptables_test.go | 23 +- 11 files changed, 550 insertions(+), 209 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/go.mod b/go.mod index 304b8bf13..4802359f8 100644 --- a/go.mod +++ b/go.mod @@ -3,19 +3,20 @@ module gvisor.dev/gvisor go 1.13 require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/mock v1.3.1 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/go-cmp v0.2.0 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 + github.com/golang/mock v1.3.1 + github.com/golang/protobuf v1.3.1 + github.com/google/btree v1.0.0 + github.com/google/go-cmp v0.2.0 + github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 + github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d + github.com/kr/pty v1.1.1 + github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 + github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e + github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 + golang.org/x/net v0.0.0-20190311183353-d8887717615a + golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index 7a0bc175a..cf092956e 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,29 @@ +github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8= github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI= github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 269ba5567..0bcb232de 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -42,6 +42,13 @@ const ( NF_RETURN = -NF_REPEAT - 1 ) +var VerdictStrings = map[int32]string{ + -NF_DROP - 1: "DROP", + -NF_ACCEPT - 1: "ACCEPT", + -NF_QUEUE - 1: "QUEUE", + NF_RETURN: "RETURN", +} + // Socket options. These correspond to values in // include/uapi/linux/netfilter_ipv4/ip_tables.h. const ( @@ -179,7 +186,7 @@ const SizeOfXTCounters = 16 // the user data. type XTEntryMatch struct { MatchSize uint16 - Name [XT_EXTENSION_MAXNAMELEN]byte + Name ExtensionName Revision uint8 // Data is omitted here because it would cause XTEntryMatch to be an // extra byte larger (see http://www.catb.org/esr/structure-packing/). @@ -199,7 +206,7 @@ const SizeOfXTEntryMatch = 32 // the user data. type XTEntryTarget struct { TargetSize uint16 - Name [XT_EXTENSION_MAXNAMELEN]byte + Name ExtensionName Revision uint8 // Data is omitted here because it would cause XTEntryTarget to be an // extra byte larger (see http://www.catb.org/esr/structure-packing/). @@ -226,9 +233,9 @@ const SizeOfXTStandardTarget = 40 // ErrorName. It corresponds to struct xt_error_target in // include/uapi/linux/netfilter/x_tables.h. type XTErrorTarget struct { - Target XTEntryTarget - ErrorName [XT_FUNCTION_MAXNAMELEN]byte - _ [2]byte + Target XTEntryTarget + Name ErrorName + _ [2]byte } // SizeOfXTErrorTarget is the size of an XTErrorTarget. @@ -237,7 +244,7 @@ const SizeOfXTErrorTarget = 64 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. type IPTGetinfo struct { - Name [XT_TABLE_MAXNAMELEN]byte + Name TableName ValidHooks uint32 HookEntry [NF_INET_NUMHOOKS]uint32 Underflow [NF_INET_NUMHOOKS]uint32 @@ -248,16 +255,11 @@ type IPTGetinfo struct { // SizeOfIPTGetinfo is the size of an IPTGetinfo. const SizeOfIPTGetinfo = 84 -// TableName returns the table name. -func (info *IPTGetinfo) TableName() string { - return tableName(info.Name[:]) -} - // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. type IPTGetEntries struct { - Name [XT_TABLE_MAXNAMELEN]byte + Name TableName Size uint32 _ [4]byte // Entrytable is omitted here because it would cause IPTGetEntries to @@ -266,34 +268,22 @@ type IPTGetEntries struct { // Entrytable [0]IPTEntry } -// TableName returns the entries' table name. -func (entries *IPTGetEntries) TableName() string { - return tableName(entries.Name[:]) -} - // SizeOfIPTGetEntries is the size of an IPTGetEntries. const SizeOfIPTGetEntries = 40 -// KernelIPTGetEntries is identical to IPTEntry, but includes the Elems field. -// This struct marshaled via the binary package to write an KernelIPTGetEntries -// to userspace. +// KernelIPTGetEntries is identical to IPTGetEntries, but includes the +// Entrytable field. This struct marshaled via the binary package to write an +// KernelIPTGetEntries to userspace. type KernelIPTGetEntries struct { - Name [XT_TABLE_MAXNAMELEN]byte - Size uint32 - _ [4]byte + IPTGetEntries Entrytable []KernelIPTEntry } -// TableName returns the entries' table name. -func (entries *KernelIPTGetEntries) TableName() string { - return tableName(entries.Name[:]) -} - // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. type IPTReplace struct { - Name [XT_TABLE_MAXNAMELEN]byte + Name TableName ValidHooks uint32 NumEntries uint32 Size uint32 @@ -306,14 +296,40 @@ type IPTReplace struct { // Entries [0]IPTEntry } +type KernelIPTReplace struct { + IPTReplace + Entries [0]IPTEntry +} + // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 -func tableName(name []byte) string { - for i, c := range name { +type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte + +// String implements fmt.Stringer. +func (en ExtensionName) String() string { + return name(en[:]) +} + +type TableName [XT_TABLE_MAXNAMELEN]byte + +// String implements fmt.Stringer. +func (tn TableName) String() string { + return name(tn[:]) +} + +type ErrorName [XT_FUNCTION_MAXNAMELEN]byte + +// String implements fmt.Stringer. +func (fn ErrorName) String() string { + return name(fn[:]) +} + +func name(cstring []byte) string { + for i, c := range cstring { if c == 0 { - return string(name[:i]) + return string(cstring[:i]) } } - return string(name) + return string(cstring) } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 5eb06bbf4..b70047d81 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -14,6 +14,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/log", "//pkg/sentry/kernel", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 9f87c32f1..8c7f3c7fc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -35,6 +36,7 @@ const errorTargetName = "ERROR" // metadata is opaque to netstack. It holds data that we need to translate // between Linux's and netstack's iptables representations. +// TODO(gvisor.dev/issue/170): This might be removable. type metadata struct { HookEntry [linux.NF_INET_NUMHOOKS]uint32 Underflow [linux.NF_INET_NUMHOOKS]uint32 @@ -51,7 +53,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // Find the appropriate table. - table, err := findTable(ep, info.TableName()) + table, err := findTable(ep, info.Name.String()) if err != nil { return linux.IPTGetinfo{}, err } @@ -82,18 +84,19 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i } // Find the appropriate table. - table, err := findTable(ep, userEntries.TableName()) + table, err := findTable(ep, userEntries.Name.String()) if err != nil { return linux.KernelIPTGetEntries{}, err } // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, _, err := convertNetstackToBinary(userEntries.TableName(), table) + entries, _, err := convertNetstackToBinary(userEntries.Name.String(), table) if err != nil { return linux.KernelIPTGetEntries{}, err } if binary.Size(entries) > uintptr(outLen) { + log.Infof("Insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } @@ -142,103 +145,63 @@ func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPT // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(name) { + log.Infof("Table name too long.") return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument } copy(entries.Name[:], name) - // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of - // these chains ends with an unconditional policy entry. - for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ { - chain, ok := table.BuiltinChains[hook] - if !ok { - // This table doesn't support this hook. - continue - } - - // Sanity check. - if len(chain.Rules) < 1 { - return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument - } - - for ruleIdx, rule := range chain.Rules { - // If this is the first rule of a builtin chain, set - // the metadata hook entry point. - if ruleIdx == 0 { + for ruleIdx, rule := range table.Rules { + // Is this a chain entry point? + for hook, hookRuleIdx := range table.BuiltinChains { + if hookRuleIdx == ruleIdx { meta.HookEntry[hook] = entries.Size } - - // Each rule corresponds to an entry. - entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, + } + // Is this a chain underflow point? The underflow rule is the last rule + // in the chain, and is an unconditional rule (i.e. it matches any + // packet). This is enforced when saving iptables. + for underflow, underflowRuleIdx := range table.Underflows { + if underflowRuleIdx == ruleIdx { + meta.Underflow[underflow] = entries.Size } + } - for _, matcher := range rule.Matchers { - // Serialize the matcher and add it to the - // entry. - serialized := marshalMatcher(matcher) - entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) - } + // Each rule corresponds to an entry. + entry := linux.KernelIPTEntry{ + IPTEntry: linux.IPTEntry{ + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } - // Serialize and append the target. - serialized := marshalTarget(rule.Target) + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) entry.Elems = append(entry.Elems, serialized...) entry.NextOffset += uint16(len(serialized)) - - // The underflow rule is the last rule in the chain, - // and is an unconditional rule (i.e. it matches any - // packet). This is enforced when saving iptables. - if ruleIdx == len(chain.Rules)-1 { - meta.Underflow[hook] = entries.Size - } - - entries.Size += uint32(entry.NextOffset) - entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + entry.TargetOffset += uint16(len(serialized)) } - } + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + entry.Elems = append(entry.Elems, serialized...) + entry.NextOffset += uint16(len(serialized)) - // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of - // these starts with an error node holding the chain's name and ends - // with an unconditional return. - - // Lastly, each table ends with an unconditional error target rule as - // its final entry. - errorEntry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, + entries.Size += uint32(entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + meta.NumEntries++ } - var errorTarget linux.XTErrorTarget - errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget - copy(errorTarget.ErrorName[:], errorTargetName) - copy(errorTarget.Target.Name[:], errorTargetName) - - // Serialize and add it to the list of entries. - errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget) - serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget) - errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...) - errorEntry.NextOffset += uint16(len(serializedErrorTarget)) - - entries.Size += uint32(errorEntry.NextOffset) - entries.Entrytable = append(entries.Entrytable, errorEntry) - meta.NumEntries++ - meta.Size = entries.Size + meta.Size = entries.Size return entries, meta, nil } func marshalMatcher(matcher iptables.Matcher) []byte { switch matcher.(type) { default: - // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so - // any call to marshalMatcher will panic. + // TODO(gvisor.dev/issue/170): We don't support any matchers + // yet, so any call to marshalMatcher will panic. panic(fmt.Errorf("unknown matcher of type %T", matcher)) } } @@ -246,28 +209,46 @@ func marshalMatcher(matcher iptables.Matcher) []byte { func marshalTarget(target iptables.Target) []byte { switch target.(type) { case iptables.UnconditionalAcceptTarget: - return marshalUnconditionalAcceptTarget() + return marshalStandardTarget(iptables.Accept) + case iptables.UnconditionalDropTarget: + return marshalStandardTarget(iptables.Drop) + case iptables.PanicTarget: + return marshalPanicTarget() default: panic(fmt.Errorf("unknown target of type %T", target)) } } -func marshalUnconditionalAcceptTarget() []byte { +func marshalStandardTarget(verdict iptables.Verdict) []byte { // The target's name will be the empty string. target := linux.XTStandardTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTStandardTarget, }, - Verdict: translateStandardVerdict(iptables.Accept), + Verdict: translateFromStandardVerdict(verdict), } ret := make([]byte, 0, linux.SizeOfXTStandardTarget) return binary.Marshal(ret, usermem.ByteOrder, target) } -// translateStandardVerdict translates verdicts the same way as the iptables +func marshalPanicTarget() []byte { + // This is an error target named error + target := linux.XTErrorTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTErrorTarget, + }, + } + copy(target.Name[:], errorTargetName) + copy(target.Target.Name[:], errorTargetName) + + ret := make([]byte, 0, linux.SizeOfXTErrorTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +// translateFromStandardVerdict translates verdicts the same way as the iptables // tool. -func translateStandardVerdict(verdict iptables.Verdict) int32 { +func translateFromStandardVerdict(verdict iptables.Verdict) int32 { switch verdict { case iptables.Accept: return -linux.NF_ACCEPT - 1 @@ -280,7 +261,269 @@ func translateStandardVerdict(verdict iptables.Verdict) int32 { case iptables.Jump: // TODO(gvisor.dev/issue/170): Support Jump. panic("Jump isn't supported yet") + } + panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) +} + +// translateToStandardVerdict translates from the value in a +// linux.XTStandardTarget to an iptables.Verdict. +func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { + // TODO(gvisor.dev/issue/170): Support other verdicts. + switch val { + case -linux.NF_ACCEPT - 1: + return iptables.Accept, nil + case -linux.NF_DROP - 1: + return iptables.Drop, nil + case -linux.NF_QUEUE - 1: + log.Infof("Unsupported iptables verdict QUEUE.") + case linux.NF_RETURN: + log.Infof("Unsupported iptables verdict RETURN.") + } + log.Infof("Unknown iptables verdict %d.", val) + return iptables.Invalid, syserr.ErrInvalidArgument +} + +// SetEntries sets iptables rules for a single table. See +// net/ipv4/netfilter/ip_tables.c:translate_table for reference. +func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { + printReplace(optVal) + + // Get the basic rules data (struct ipt_replace). + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + var replace linux.IPTReplace + replaceBuf := optVal[:linux.SizeOfIPTReplace] + optVal = optVal[linux.SizeOfIPTReplace:] + binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) + + // TODO(gvisor.dev/issue/170): Support other tables. + var table iptables.Table + switch replace.Name.String() { + case iptables.TablenameFilter: + table = iptables.EmptyFilterTable() default: - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) + log.Infof(fmt.Sprintf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())) + return syserr.ErrInvalidArgument + } + + // Convert input into a list of rules and their offsets. + var offset uint32 + var offsets []uint32 + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIPTEntry { + return syserr.ErrInvalidArgument + } + var entry linux.IPTEntry + buf := optVal[:linux.SizeOfIPTEntry] + optVal = optVal[linux.SizeOfIPTEntry:] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + if entry.TargetOffset != linux.SizeOfIPTEntry { + // TODO(gvisor.dev/issue/170): Support matchers. + return syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support IPTIP + // filtering. We reject any nonzero IPTIP values for now. + emptyIPTIP := linux.IPTIP{} + if entry.IP != emptyIPTIP { + return syserr.ErrInvalidArgument + } + + // Get the target of the rule. + target, consumed, err := parseTarget(optVal) + if err != nil { + return err + } + optVal = optVal[consumed:] + + table.Rules = append(table.Rules, iptables.Rule{Target: target}) + offsets = append(offsets, offset) + offset += linux.SizeOfIPTEntry + consumed + } + + // Go through the list of supported hooks for this table and, for each + // one, set the rule it corresponds to. + for hook, _ := range replace.HookEntry { + if table.ValidHooks()&uint32(hook) != 0 { + hk := hookFromLinux(hook) + for ruleIdx, offset := range offsets { + if offset == replace.HookEntry[hook] { + table.BuiltinChains[hk] = ruleIdx + } + if offset == replace.Underflow[hook] { + table.Underflows[hk] = ruleIdx + } + } + if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset { + log.Infof("Hook %v is unset.", hk) + return syserr.ErrInvalidArgument + } + if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset { + log.Infof("Underflow %v is unset.", hk) + return syserr.ErrInvalidArgument + } + } + } + + ipt := stack.IPTables() + table.SetMetadata(metadata{ + HookEntry: replace.HookEntry, + Underflow: replace.Underflow, + NumEntries: replace.NumEntries, + Size: replace.Size, + }) + ipt.Tables[replace.Name.String()] = table + // TODO: Do we need to worry about locking? We could write rules while + // packets traverse tables. + stack.SetIPTables(ipt) + + return nil +} + +// parseTarget parses a target from the start of optVal and returns the target +// along with the number of bytes it occupies in optVal. +func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { + if len(optVal) < linux.SizeOfXTEntryTarget { + return nil, 0, syserr.ErrInvalidArgument + } + var target linux.XTEntryTarget + buf := optVal[:linux.SizeOfXTEntryTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + switch target.Name.String() { + case "": + // Standard target. + if len(optVal) < linux.SizeOfXTStandardTarget { + return nil, 0, syserr.ErrInvalidArgument + } + var target linux.XTStandardTarget + buf = optVal[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + + verdict, err := translateToStandardVerdict(target.Verdict) + if err != nil { + return nil, 0, err + } + switch verdict { + case iptables.Accept: + return iptables.UnconditionalAcceptTarget{}, linux.SizeOfXTStandardTarget, nil + case iptables.Drop: + // TODO(gvisor.dev/issue/170): Return an + // iptables.UnconditionalDropTarget to support DROP. + log.Infof("netfilter DROP is not supported yet.") + return nil, 0, syserr.ErrInvalidArgument + default: + panic(fmt.Sprintf("Unknown verdict: %v", verdict)) + } + + case errorTargetName: + // Error target. + if len(optVal) < linux.SizeOfXTErrorTarget { + return nil, 0, syserr.ErrInvalidArgument + } + var target linux.XTErrorTarget + buf = optVal[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named errorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch target.Name.String() { + case errorTargetName: + return iptables.PanicTarget{}, linux.SizeOfXTErrorTarget, nil + default: + log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", target.Name.String()) + return nil, 0, syserr.ErrInvalidArgument + } + } + + // Unknown target. + log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) + return nil, 0, syserr.ErrInvalidArgument +} + +func chainNameFromHook(hook int) string { + switch hook { + case linux.NF_INET_PRE_ROUTING: + return iptables.ChainNamePrerouting + case linux.NF_INET_LOCAL_IN: + return iptables.ChainNameInput + case linux.NF_INET_FORWARD: + return iptables.ChainNameForward + case linux.NF_INET_LOCAL_OUT: + return iptables.ChainNameOutput + case linux.NF_INET_POST_ROUTING: + return iptables.ChainNamePostrouting + } + panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain")) +} + +func hookFromLinux(hook int) iptables.Hook { + switch hook { + case linux.NF_INET_PRE_ROUTING: + return iptables.Prerouting + case linux.NF_INET_LOCAL_IN: + return iptables.Input + case linux.NF_INET_FORWARD: + return iptables.Forward + case linux.NF_INET_LOCAL_OUT: + return iptables.Output + case linux.NF_INET_POST_ROUTING: + return iptables.Postrouting + } + panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain")) +} + +// printReplace prints information about the struct ipt_replace in optVal. It +// is only for debugging. +func printReplace(optVal []byte) { + // Basic replace info. + var replace linux.IPTReplace + replaceBuf := optVal[:linux.SizeOfIPTReplace] + optVal = optVal[linux.SizeOfIPTReplace:] + binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) + log.Infof("kevin: Replacing table %q: %+v", replace.Name.String(), replace) + + // Read in the list of entries at the end of replace. + var totalOffset uint16 + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + var entry linux.IPTEntry + entryBuf := optVal[:linux.SizeOfIPTEntry] + binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry) + log.Infof("kevin: Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry) + + totalOffset += entry.NextOffset + if entry.TargetOffset == linux.SizeOfIPTEntry { + log.Infof("kevin: Entry has no matches.") + } else { + log.Infof("kevin: Entry has matches.") + } + + var target linux.XTEntryTarget + targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget] + binary.Unmarshal(targetBuf, usermem.ByteOrder, &target) + log.Infof("kevin: Target named %q: %+v", target.Name.String(), target) + + switch target.Name.String() { + case "": + var standardTarget linux.XTStandardTarget + stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget] + binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget) + log.Infof("kevin: Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict) + case errorTargetName: + var errorTarget linux.XTErrorTarget + etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget] + binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget) + log.Infof("kevin: Error target with name %q.", errorTarget.Name.String()) + default: + log.Infof("kevin: Unknown target type.") + } + + optVal = optVal[entry.NextOffset:] } } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..f7caa45b4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -326,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } @@ -1356,6 +1356,27 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return nil } + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + if name == linux.IPT_SO_SET_REPLACE { + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + if err := netfilter.SetEntries(stack.(*Stack).Stack, optVal); err != nil { + return err + } + return nil + } else if name == linux.IPT_SO_SET_ADD_COUNTERS { + // TODO(gvisor.dev/issue/170): Counter support. + return nil + } + } + return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 68c68d4aa..9e7005374 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -17,65 +17,107 @@ package iptables const ( - tablenameNat = "nat" - tablenameMangle = "mangle" + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" ) +// TODO: Make this an iota? Faster! Do it. // Chain names as defined by net/ipv4/netfilter/ip_tables.c. const ( - chainNamePrerouting = "PREROUTING" - chainNameInput = "INPUT" - chainNameForward = "FORWARD" - chainNameOutput = "OUTPUT" - chainNamePostrouting = "POSTROUTING" + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" ) +const HookUnset = -1 + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() IPTables { return IPTables{ Tables: map[string]Table{ - tablenameNat: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Input: unconditionalAcceptChain(chainNameInput), - Output: unconditionalAcceptChain(chainNameOutput), - Postrouting: unconditionalAcceptChain(chainNamePostrouting), + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: PanicTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Input: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, - Postrouting: UnconditionalAcceptTarget{}, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, }, - UserChains: map[string]Chain{}, + UserChains: map[string]int{}, }, - tablenameMangle: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Output: unconditionalAcceptChain(chainNameOutput), + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: PanicTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, }, - UserChains: map[string]Chain{}, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: PanicTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, }, }, Priorities: map[Hook][]string{ - Prerouting: []string{tablenameMangle, tablenameNat}, - Output: []string{tablenameMangle, tablenameNat}, + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, } } -func unconditionalAcceptChain(name string) Chain { - return Chain{ - Name: name, - Rules: []Rule{ - Rule{ - Target: UnconditionalAcceptTarget{}, - }, +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, }, + UserChains: map[string]int{}, } } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 19a7f77e3..03c9f19ff 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -33,3 +33,11 @@ type UnconditionalDropTarget struct{} func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { return Drop, "" } + +// PanicTarget just panics. +type PanicTarget struct{} + +// Actions implements Target.Action. +func (PanicTarget) Action(packet buffer.VectorisedView) (Verdict, string) { + panic("PanicTarget triggered.") +} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 42a79ef9f..76364ff1f 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -61,9 +61,12 @@ const ( type Verdict int const ( + // Invalid indicates an unkonwn or erroneous verdict. + Invalid Verdict = iota + // Accept indicates the packet should continue traversing netstack as // normal. - Accept Verdict = iota + Accept // Drop inicates the packet should be dropped, stopping traversing // netstack. @@ -109,24 +112,18 @@ type IPTables struct { // * nat // * mangle type Table struct { - // BuiltinChains holds the un-deletable chains built into netstack. If - // a hook isn't present in the map, this table doesn't utilize that - // hook. - BuiltinChains map[Hook]Chain + // A table is just a list of rules with some entrypoints. + Rules []Rule + + BuiltinChains map[Hook]int + + Underflows map[Hook]int - // DefaultTargets holds a target for each hook that will be executed if - // chain traversal doesn't yield a verdict. - DefaultTargets map[Hook]Target + // DefaultTargets map[Hook]int // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. - UserChains map[string]Chain - - // Chains maps names to chains for both builtin and user-defined chains. - // Its entries point to Chains already either in BuiltinChains or - // UserChains, and its purpose is to make looking up tables by name - // fast. - Chains map[string]*Chain + UserChains map[string]int // Metadata holds information about the Table that is useful to users // of IPTables, but not to the netstack IPTables code itself. @@ -152,20 +149,20 @@ func (table *Table) SetMetadata(metadata interface{}) { table.metadata = metadata } -// A Chain defines a list of rules for packet processing. When a packet -// traverses a chain, it is checked against each rule until either a rule -// returns a verdict or the chain ends. -// -// By convention, builtin chains end with a rule that matches everything and -// returns either Accept or Drop. User-defined chains end with Return. These -// aren't strictly necessary here, but the iptables tool writes tables this way. -type Chain struct { - // Name is the chain name. - Name string - - // Rules is the list of rules to traverse. - Rules []Rule -} +//// A Chain defines a list of rules for packet processing. When a packet +//// traverses a chain, it is checked against each rule until either a rule +//// returns a verdict or the chain ends. +//// +//// By convention, builtin chains end with a rule that matches everything and +//// returns either Accept or Drop. User-defined chains end with Return. These +//// aren't strictly necessary here, but the iptables tool writes tables this way. +//type Chain struct { +// // Name is the chain name. +// Name string + +// // Rules is the list of rules to traverse. +// Rules []Rule +//} // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 923f44e68..0cb668635 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -44,6 +44,7 @@ func (FilterInputDropUDP) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputDropUDP) ContainerAction(ip net.IP) error { if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { + // if err := filterTable("-A", "INPUT", "-j", "ACCEPT"); err != nil { return err } diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index bfbf1bb87..e761e0f2f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -23,6 +23,7 @@ import ( "time" "flag" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/dockerutil" "gvisor.dev/gvisor/runsc/testutil" @@ -166,14 +167,14 @@ func TestFilterInputDropUDP(t *testing.T) { } } -func TestFilterInputDropUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropUDPPort{}); err != nil { - t.Fatal(err) - } -} - -func TestFilterInputDropDifferentUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { - t.Fatal(err) - } -} +// func TestFilterInputDropUDPPort(t *testing.T) { +// if err := singleTest(FilterInputDropUDPPort{}); err != nil { +// t.Fatal(err) +// } +// } + +// func TestFilterInputDropDifferentUDPPort(t *testing.T) { +// if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { +// t.Fatal(err) +// } +// } -- cgit v1.2.3 From 1e1921e2acdb7357972257219fdffb9edf17bf55 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 8 Jan 2020 11:15:46 -0800 Subject: Minor fixes to comments and logging --- pkg/sentry/socket/netfilter/netfilter.go | 10 +++++++--- pkg/sentry/socket/netstack/netstack.go | 8 +++++--- pkg/sentry/syscalls/linux/sys_socket.go | 2 +- pkg/tcpip/iptables/targets.go | 2 +- pkg/tcpip/iptables/types.go | 28 ++++++---------------------- test/iptables/filter_input.go | 3 +-- test/iptables/iptables_test.go | 22 +++++++++++----------- 7 files changed, 32 insertions(+), 43 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 8c7f3c7fc..b7867a576 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -157,9 +157,7 @@ func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPT meta.HookEntry[hook] = entries.Size } } - // Is this a chain underflow point? The underflow rule is the last rule - // in the chain, and is an unconditional rule (i.e. it matches any - // packet). This is enforced when saving iptables. + // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { meta.Underflow[underflow] = entries.Size @@ -290,6 +288,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { + log.Infof("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal)) return syserr.ErrInvalidArgument } var replace linux.IPTReplace @@ -313,6 +312,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { // Get the struct ipt_entry. if len(optVal) < linux.SizeOfIPTEntry { + log.Infof("netfilter: optVal has insufficient size for entry %d", len(optVal)) return syserr.ErrInvalidArgument } var entry linux.IPTEntry @@ -328,6 +328,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // filtering. We reject any nonzero IPTIP values for now. emptyIPTIP := linux.IPTIP{} if entry.IP != emptyIPTIP { + log.Infof("netfilter: non-empty struct iptip found") return syserr.ErrInvalidArgument } @@ -386,6 +387,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // along with the number of bytes it occupies in optVal. func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { if len(optVal) < linux.SizeOfXTEntryTarget { + log.Infof("netfilter: optVal has insufficient size for entry target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } var target linux.XTEntryTarget @@ -395,6 +397,7 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { case "": // Standard target. if len(optVal) < linux.SizeOfXTStandardTarget { + log.Infof("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } var target linux.XTStandardTarget @@ -420,6 +423,7 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { case errorTargetName: // Error target. if len(optVal) < linux.SizeOfXTErrorTarget { + log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } var target linux.XTErrorTarget diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f7caa45b4..8c07eef4b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -326,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } @@ -1357,7 +1357,8 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - if name == linux.IPT_SO_SET_REPLACE { + switch name { + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument } @@ -1371,7 +1372,8 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return err } return nil - } else if name == linux.IPT_SO_SET_ADD_COUNTERS { + + case linux.IPT_SO_SET_ADD_COUNTERS: // TODO(gvisor.dev/issue/170): Counter support. return nil } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 4b5aafcc0..cda517a81 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -41,7 +41,7 @@ const maxListenBacklog = 1024 const maxAddrLen = 200 // maxOptLen is the maximum sockopt parameter length we're willing to accept. -const maxOptLen = 1024 +const maxOptLen = 1024 * 8 // maxControlLen is the maximum length of the msghdr.msg_control buffer we're // willing to accept. Note that this limit is smaller than Linux, which allows diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 03c9f19ff..2c3598e3d 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -34,7 +34,7 @@ func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, st return Drop, "" } -// PanicTarget just panics. +// PanicTarget just panics. It represents a target that should be unreachable. type PanicTarget struct{} // Actions implements Target.Action. diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 76364ff1f..fe0394a31 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -107,20 +107,19 @@ type IPTables struct { Priorities map[Hook][]string } -// A Table defines a set of chains and hooks into the network stack. The -// currently supported tables are: -// * nat -// * mangle +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules with some metadata for entrypoints and such. type Table struct { - // A table is just a list of rules with some entrypoints. + // Rules holds the rules that make up the table. Rules []Rule + // BuiltinChains maps builtin chains to their entrypoints. BuiltinChains map[Hook]int + // Underflows maps builtin chains to their underflow point (i.e. the + // rule to execute if the chain returns without a verdict). Underflows map[Hook]int - // DefaultTargets map[Hook]int - // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. UserChains map[string]int @@ -149,21 +148,6 @@ func (table *Table) SetMetadata(metadata interface{}) { table.metadata = metadata } -//// A Chain defines a list of rules for packet processing. When a packet -//// traverses a chain, it is checked against each rule until either a rule -//// returns a verdict or the chain ends. -//// -//// By convention, builtin chains end with a rule that matches everything and -//// returns either Accept or Drop. User-defined chains end with Return. These -//// aren't strictly necessary here, but the iptables tool writes tables this way. -//type Chain struct { -// // Name is the chain name. -// Name string - -// // Rules is the list of rules to traverse. -// Rules []Rule -//} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 0cb668635..34a85db97 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -43,8 +43,7 @@ func (FilterInputDropUDP) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { - // if err := filterTable("-A", "INPUT", "-j", "ACCEPT"); err != nil { + if err := filterTable("-A", "INPUT", "-j", "ACCEPT"); err != nil { return err } diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index e761e0f2f..2465a4e65 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -167,14 +167,14 @@ func TestFilterInputDropUDP(t *testing.T) { } } -// func TestFilterInputDropUDPPort(t *testing.T) { -// if err := singleTest(FilterInputDropUDPPort{}); err != nil { -// t.Fatal(err) -// } -// } - -// func TestFilterInputDropDifferentUDPPort(t *testing.T) { -// if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { -// t.Fatal(err) -// } -// } +func TestFilterInputDropUDPPort(t *testing.T) { + if err := singleTest(FilterInputDropUDPPort{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputDropDifferentUDPPort(t *testing.T) { + if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { + t.Fatal(err) + } +} -- cgit v1.2.3 From a271bccfc61390be64ca0175b8fc7d20e66d05b6 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 8 Jan 2020 14:16:38 -0800 Subject: Rename tcpip.SockOpt{,Int} PiperOrigin-RevId: 288772878 --- pkg/sentry/socket/netstack/netstack.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 8 ++++---- pkg/tcpip/stack/transport_test.go | 4 ++-- pkg/tcpip/tcpip.go | 10 +++++----- pkg/tcpip/transport/icmp/endpoint.go | 4 ++-- pkg/tcpip/transport/packet/endpoint.go | 4 ++-- pkg/tcpip/transport/raw/endpoint.go | 4 ++-- pkg/tcpip/transport/tcp/endpoint.go | 4 ++-- pkg/tcpip/transport/udp/endpoint.go | 4 ++-- 9 files changed, 23 insertions(+), 23 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..5f91a0d1a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -224,7 +224,7 @@ type commonEndpoint interface { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. @@ -232,7 +232,7 @@ type commonEndpoint interface { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 529a7a7a9..fcba49435 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -177,7 +177,7 @@ type Endpoint interface { // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. @@ -185,7 +185,7 @@ type Endpoint interface { // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. @@ -851,11 +851,11 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 748ce4ea5..095346f0b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -103,12 +103,12 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..b172d71b0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -425,7 +425,7 @@ type Endpoint interface { // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOpt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) *Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. @@ -433,7 +433,7 @@ type Endpoint interface { // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOpt) (int, *Error) + GetSockOptInt(SockOptInt) (int, *Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -488,13 +488,13 @@ type WriteOptions struct { Atomic bool } -// SockOpt represents socket options which values have the int type. -type SockOpt int +// SockOptInt represents socket options which values have the int type. +type SockOptInt int const ( // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOpt = iota + ReceiveQueueSizeOption SockOptInt = iota // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9c40931b5..5816ce49a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -351,12 +351,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0010b5e5f..6360ce880 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -251,12 +251,12 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..0fd9c456a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,12 +510,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fe629aa40..f79154b95 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1146,7 +1146,7 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { } // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -1447,7 +1447,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..dae373ea7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -457,7 +457,7 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } @@ -661,7 +661,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 -- cgit v1.2.3 From d530df2f95c3f75488ecc56b8fd205c3ee0966f8 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 8 Jan 2020 15:39:22 -0800 Subject: Introduce tcpip.SockOptBool ...and port V6OnlyOption to it. PiperOrigin-RevId: 288789451 --- pkg/sentry/socket/netstack/netstack.go | 21 ++++-- pkg/sentry/socket/unix/transport/unix.go | 16 +++++ pkg/tcpip/stack/ndp_test.go | 15 ++--- pkg/tcpip/stack/transport_demuxer_test.go | 48 +++++++------- pkg/tcpip/stack/transport_test.go | 10 +++ pkg/tcpip/tcpip.go | 21 ++++-- pkg/tcpip/transport/icmp/endpoint.go | 10 +++ pkg/tcpip/transport/packet/endpoint.go | 22 +++++-- pkg/tcpip/transport/raw/endpoint.go | 40 +++++++----- pkg/tcpip/transport/tcp/dual_stack_test.go | 8 +-- pkg/tcpip/transport/tcp/endpoint.go | 75 ++++++++++++---------- pkg/tcpip/transport/tcp/tcp_test.go | 8 +-- pkg/tcpip/transport/tcp/testing/context/context.go | 6 +- pkg/tcpip/transport/udp/endpoint.go | 60 +++++++++-------- pkg/tcpip/transport/udp/udp_test.go | 2 +- 15 files changed, 224 insertions(+), 138 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5f91a0d1a..9e0d69046 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -222,6 +222,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and + // transport.Endpoint.SetSockOptBool. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -230,6 +234,10 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and + // transport.Endpoint.GetSockOpt. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -1213,12 +1221,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf return nil, syserr.ErrInvalidArgument } - var v tcpip.V6OnlyOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + var o uint32 + if v { + o = 1 + } + return int32(o), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1621,7 +1632,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index fcba49435..37c7ac3c1 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptBool sets a socket option for simple cases when a value has + // the int type. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -183,6 +187,10 @@ type Endpoint interface { // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptBool gets a socket option for simple cases when a return + // value has the int type. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -851,10 +859,18 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 070d80c8d..e51462a55 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1701,9 +1701,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Connect(dstAddr); err != nil { t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) @@ -1728,9 +1727,8 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) @@ -2066,9 +2064,8 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 33dbc0536..df5ced887 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -61,11 +61,7 @@ func (c *testContext) createV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } @@ -201,54 +197,54 @@ func TestDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, }, map[string][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, + {0, "dev0"}, + {0, "dev1"}, + {0, "dev2"}, }, map[string][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, + "dev0": {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, + "dev1": {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, + "dev2": {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, + {1, "dev0"}, + {1, "dev0"}, + {1, "dev1"}, + {1, "dev1"}, + {1, "dev1"}, + {1, ""}, }, map[string][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + "dev0": {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, + "dev999": {0, 0, 0, 0, 0, 1}, }, }, } { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 095346f0b..f50604a8a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -102,11 +102,21 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptBool sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b172d71b0..1eca76c30 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -423,6 +423,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptBool sets a socket option, for simple cases where a value + // has the bool type. + SetSockOptBool(opt SockOptBool, v bool) *Error + // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error @@ -431,6 +435,10 @@ type Endpoint interface { // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptBool gets a socket option for simple cases where a return + // value has the bool type. + GetSockOptBool(SockOptBool) (bool, *Error) + // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. GetSockOptInt(SockOptInt) (int, *Error) @@ -488,6 +496,15 @@ type WriteOptions struct { Atomic bool } +// SockOptBool represents socket options which values have the bool type. +type SockOptBool int + +const ( + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption SockOptBool = iota +) + // SockOptInt represents socket options which values have the int type. type SockOptInt int @@ -521,10 +538,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 -// socket is to be restricted to sending and receiving IPv6 packets only. -type V6OnlyOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 5816ce49a..c7ce74cdd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -350,11 +350,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 6360ce880..07ffa8aba 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported + return tcpip.ErrUnknownProtocolOption } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNotSupported } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + // HandlePacket implements stack.PacketEndpoint.HandlePacket. func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 0fd9c456a..85f7eb76b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -509,11 +509,36 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index dfaa4a559..4f361b226 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) { // Make sure we get the same error when calling the original ep and the // new one. This validates that v4-mapped endpoints are still able to // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) } @@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) { // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { t.Fatalf("GetSockOpt failed failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f79154b95..2ac1b6877 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1145,6 +1145,29 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v + } + + return nil +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -1289,23 +1312,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyMSSChanged) return nil - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -1446,6 +1452,25 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -1540,22 +1565,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9d7b0910d..15745ebd4 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4028,12 +4028,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) } case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) } default: t.Fatalf("unknown network: '%s'", network) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 50c81aa65..822907998 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -475,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { + if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dae373ea7..1a5ee6317 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} - -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.v6only = v != 0 + e.v6only = v + } + + return nil +} +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -660,6 +667,25 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -695,22 +721,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 65382b7f1..149fff999 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } else if flow.isBroadcast() { -- cgit v1.2.3 From 0999ae8b34d83a4b2ea8342d0459c8131c35d6e1 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 8 Jan 2020 15:57:25 -0800 Subject: Getting a panic when running tests. For some reason the filter table is ending up with the wrong chains and is indexing -1 into rules. --- pkg/sentry/socket/netfilter/netfilter.go | 17 ++++++---------- pkg/sentry/socket/netstack/netstack.go | 12 +++++++++-- pkg/tcpip/BUILD | 1 - pkg/tcpip/iptables/BUILD | 1 + pkg/tcpip/iptables/iptables.go | 35 +++++++++++++++++++++++++------- pkg/tcpip/iptables/targets.go | 8 ++++---- pkg/tcpip/iptables/types.go | 8 +++----- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ipv4/BUILD | 1 + pkg/tcpip/network/ipv4/ipv4.go | 8 ++++++-- pkg/tcpip/network/ipv6/ipv6.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/registration.go | 2 +- pkg/tcpip/tcpip.go | 4 ---- 14 files changed, 63 insertions(+), 40 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 57785220e..3a857ef6d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -45,7 +44,7 @@ type metadata struct { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo if _, err := t.CopyIn(outPtr, &info); err != nil { @@ -53,7 +52,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // Find the appropriate table. - table, err := findTable(ep, info.Name.String()) + table, err := findTable(stack, info.Name.String()) if err != nil { return linux.IPTGetinfo{}, err } @@ -76,7 +75,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // GetEntries returns netstack's iptables rules encoded for the iptables tool. -func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := t.CopyIn(outPtr, &userEntries); err != nil { @@ -84,7 +83,7 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i } // Find the appropriate table. - table, err := findTable(ep, userEntries.Name.String()) + table, err := findTable(stack, userEntries.Name.String()) if err != nil { return linux.KernelIPTGetEntries{}, err } @@ -103,12 +102,8 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i return entries, nil } -func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) { - ipt, err := ep.IPTables() - if err != nil { - return iptables.Table{}, syserr.FromError(err) - } - table, ok := ipt.Tables[tableName] +func findTable(stack *stack.Stack, tableName string) (iptables.Table, *syserr.Error) { + table, ok := stack.IPTables().Tables[tableName] if !ok { return iptables.Table{}, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8c07eef4b..86a8104df 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -826,7 +826,11 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) if err != nil { return nil, err } @@ -837,7 +841,11 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 65d4d0cd8..36bc3a63b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,7 +15,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/buffer", - "//pkg/tcpip/iptables", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 6ed7c6da0..2893c80cd 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -12,6 +12,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 025a4679d..aff8a680b 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -16,7 +16,12 @@ // tool. package iptables -import "github.com/google/netstack/tcpip" +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" +) const ( TablenameNat = "nat" @@ -135,31 +140,47 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { - verdict := it.checkTable(tablename) + verdict := it.checkTable(hook, pkt, tablename) switch verdict { - // TODO: We either got a final verdict or simply continue on. + // If the table returns Accept, move on to the next table. + case Accept: + continue + // The Drop verdict is final. + case Drop: + log.Infof("kevin: Packet dropped") + return false + case Stolen, Queue, Repeat, None, Jump, Return, Continue: + panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) } } + + // Every table returned Accept. + log.Infof("kevin: Packet accepted") + return true } -func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) bool { +func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict { log.Infof("kevin: iptables.IPTables: checking table %q", tablename) table := it.Tables[tablename] - ruleIdx := table.BuiltinChains[hook] + log.Infof("kevin: iptables.IPTables: table %+v", table) // Start from ruleIdx and go down until a rule gives us a verdict. for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ { - verdict := checkRule(hook, pkt, table, ruleIdx) + verdict := it.checkRule(hook, pkt, table, ruleIdx) switch verdict { + // For either of these cases, this table is done with the + // packet. case Accept, Drop: return verdict + // Continue traversing the rules of the table. case Continue: continue case Stolen, Queue, Repeat, None, Jump, Return: + panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) } } - panic("Traversed past the entire list of iptables rules.") + panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename)) } func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict { diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 2c3598e3d..cb3ac1aff 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -16,13 +16,13 @@ package iptables -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import "gvisor.dev/gvisor/pkg/tcpip" // UnconditionalAcceptTarget accepts all packets. type UnconditionalAcceptTarget struct{} // Action implements Target.Action. -func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { return Accept, "" } @@ -30,7 +30,7 @@ func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, type UnconditionalDropTarget struct{} // Action implements Target.Action. -func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { return Drop, "" } @@ -38,6 +38,6 @@ func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, st type PanicTarget struct{} // Actions implements Target.Action. -func (PanicTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (PanicTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { panic("PanicTarget triggered.") } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 540f8c0b4..9f6906100 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -14,9 +14,7 @@ package iptables -import ( - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) +import "gvisor.dev/gvisor/pkg/tcpip" // A Hook specifies one of the hooks built into the network stack. // @@ -165,7 +163,7 @@ type Matcher interface { // Match returns whether the packet matches and whether the packet // should be "hotdropped", i.e. dropped immediately. This is usually // used for suspicious packets. - Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. @@ -173,5 +171,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the name of the chain to jump to. - Action(packet buffer.VectorisedView) (Verdict, string) + Action(packet tcpip.PacketBuffer) (Verdict, string) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..d88119f68 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -137,7 +137,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index aeddfcdd4..4e2aae9a3 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bbb5aafee..f856081e6 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -54,10 +55,11 @@ type endpoint struct { dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation protocol *protocol + stack *stack.Stack } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, @@ -66,6 +68,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, + stack: st, } return e, nil @@ -351,7 +354,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.NetworkHeader = headerView[:h.HeaderLength()] // iptables filtering. - if ok := iptables.Check(iptables.Input, pkt); !ok { + ipt := e.stack.IPTables() + if ok := ipt.Check(iptables.Input, pkt); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e13f1fabf..4c940e9e5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -221,7 +221,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4144d5d0f..f2d338bd1 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -467,7 +467,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP) + ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack) if err != nil { return nil, err } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..754323e82 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -282,7 +282,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..d02950c7a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -40,7 +40,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" ) @@ -446,9 +445,6 @@ type Endpoint interface { // NOTE: This method is a no-op for sockets other than TCP. ModerateRecvBuf(copied int) - // IPTables returns the iptables for this endpoint's stack. - IPTables() (iptables.IPTables, error) - // Info returns a copy to the transport endpoint info. Info() EndpointInfo -- cgit v1.2.3 From f26a576984052a235b63ec79081a8c4a8c8ffc00 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 8 Jan 2020 16:35:01 -0800 Subject: Addressed GH comments --- pkg/abi/linux/netfilter.go | 15 ++++-- pkg/sentry/socket/netfilter/netfilter.go | 87 +++++++++++++------------------- pkg/sentry/socket/netstack/netstack.go | 5 +- 3 files changed, 47 insertions(+), 60 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 35d66d622..c4f4ea0b1 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -298,6 +298,7 @@ type IPTReplace struct { // Entries [0]IPTEntry } +// KernelIPTEntry is identical to IPTReplace, but includes the Entries field. type KernelIPTReplace struct { IPTReplace Entries [0]IPTEntry @@ -306,28 +307,32 @@ type KernelIPTReplace struct { // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 +// ExtensionName holds the name of a netfilter extension. type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte // String implements fmt.Stringer. func (en ExtensionName) String() string { - return name(en[:]) + return goString(en[:]) } +// ExtensionName holds the name of a netfilter table. type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. func (tn TableName) String() string { - return name(tn[:]) + return goString(tn[:]) } +// ExtensionName holds the name of a netfilter error. These can also hold +// user-defined chains. type ErrorName [XT_FUNCTION_MAXNAMELEN]byte // String implements fmt.Stringer. -func (fn ErrorName) String() string { - return name(fn[:]) +func (en ErrorName) String() string { + return goString(en[:]) } -func name(cstring []byte) string { +func goString(cstring []byte) string { for i, c := range cstring { if c == 0 { return string(cstring[:i]) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 347342f98..799865b03 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -53,7 +53,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // Find the appropriate table. - table, err := findTable(ep, info.Name.String()) + table, err := findTable(ep, info.Name) if err != nil { return linux.IPTGetinfo{}, err } @@ -84,7 +84,7 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i } // Find the appropriate table. - table, err := findTable(ep, userEntries.Name.String()) + table, err := findTable(ep, userEntries.Name) if err != nil { return linux.KernelIPTGetEntries{}, err } @@ -96,19 +96,19 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i return linux.KernelIPTGetEntries{}, err } if binary.Size(entries) > uintptr(outLen) { - log.Infof("Insufficient GetEntries output size: %d", uintptr(outLen)) + log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } return entries, nil } -func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) { +func findTable(ep tcpip.Endpoint, tablename linux.TableName) (iptables.Table, *syserr.Error) { ipt, err := ep.IPTables() if err != nil { return iptables.Table{}, syserr.FromError(err) } - table, ok := ipt.Tables[tableName] + table, ok := ipt.Tables[tablename.String()] if !ok { return iptables.Table{}, syserr.ErrInvalidArgument } @@ -138,17 +138,17 @@ func FillDefaultIPTables(stack *stack.Stack) { // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) { +func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) { // Return values. var entries linux.KernelIPTGetEntries var meta metadata // The table name has to fit in the struct. - if linux.XT_TABLE_MAXNAMELEN < len(name) { - log.Infof("Table name too long.") + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + log.Warningf("Table name %q too long.", tablename) return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument } - copy(entries.Name[:], name) + copy(entries.Name[:], tablename) for ruleIdx, rule := range table.Rules { // Is this a chain entry point? @@ -273,11 +273,12 @@ func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { case -linux.NF_DROP - 1: return iptables.Drop, nil case -linux.NF_QUEUE - 1: - log.Infof("Unsupported iptables verdict QUEUE.") + log.Warningf("Unsupported iptables verdict QUEUE.") case linux.NF_RETURN: - log.Infof("Unsupported iptables verdict RETURN.") + log.Warningf("Unsupported iptables verdict RETURN.") + default: + log.Warningf("Unknown iptables verdict %d.", val) } - log.Infof("Unknown iptables verdict %d.", val) return iptables.Invalid, syserr.ErrInvalidArgument } @@ -288,7 +289,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { - log.Infof("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal)) + log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal)) return syserr.ErrInvalidArgument } var replace linux.IPTReplace @@ -302,7 +303,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { case iptables.TablenameFilter: table = iptables.EmptyFilterTable() default: - log.Infof(fmt.Sprintf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())) + log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument } @@ -312,7 +313,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { // Get the struct ipt_entry. if len(optVal) < linux.SizeOfIPTEntry { - log.Infof("netfilter: optVal has insufficient size for entry %d", len(optVal)) + log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal)) return syserr.ErrInvalidArgument } var entry linux.IPTEntry @@ -328,7 +329,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // filtering. We reject any nonzero IPTIP values for now. emptyIPTIP := linux.IPTIP{} if entry.IP != emptyIPTIP { - log.Infof("netfilter: non-empty struct iptip found") + log.Warningf("netfilter: non-empty struct iptip found") return syserr.ErrInvalidArgument } @@ -358,11 +359,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } } if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset { - log.Infof("Hook %v is unset.", hk) + log.Warningf("Hook %v is unset.", hk) return syserr.ErrInvalidArgument } if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset { - log.Infof("Underflow %v is unset.", hk) + log.Warningf("Underflow %v is unset.", hk) return syserr.ErrInvalidArgument } } @@ -385,7 +386,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // along with the number of bytes it occupies in optVal. func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { if len(optVal) < linux.SizeOfXTEntryTarget { - log.Infof("netfilter: optVal has insufficient size for entry target %d", len(optVal)) + log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } var target linux.XTEntryTarget @@ -395,14 +396,14 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { case "": // Standard target. if len(optVal) < linux.SizeOfXTStandardTarget { - log.Infof("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal)) + log.Warningf("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } - var target linux.XTStandardTarget + var standardTarget linux.XTStandardTarget buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - verdict, err := translateToStandardVerdict(target.Verdict) + verdict, err := translateToStandardVerdict(standardTarget.Verdict) if err != nil { return nil, 0, err } @@ -424,9 +425,9 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal)) return nil, 0, syserr.ErrInvalidArgument } - var target linux.XTErrorTarget + var errorTarget linux.XTErrorTarget buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) // Error targets are used in 2 cases: // * An actual error case. These rules have an error @@ -435,11 +436,11 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { // somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch target.Name.String() { + switch errorTarget.Name.String() { case errorTargetName: return iptables.PanicTarget{}, linux.SizeOfXTErrorTarget, nil default: - log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", target.Name.String()) + log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String()) return nil, 0, syserr.ErrInvalidArgument } } @@ -449,22 +450,6 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) { return nil, 0, syserr.ErrInvalidArgument } -func chainNameFromHook(hook int) string { - switch hook { - case linux.NF_INET_PRE_ROUTING: - return iptables.ChainNamePrerouting - case linux.NF_INET_LOCAL_IN: - return iptables.ChainNameInput - case linux.NF_INET_FORWARD: - return iptables.ChainNameForward - case linux.NF_INET_LOCAL_OUT: - return iptables.ChainNameOutput - case linux.NF_INET_POST_ROUTING: - return iptables.ChainNamePostrouting - } - panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain")) -} - func hookFromLinux(hook int) iptables.Hook { switch hook { case linux.NF_INET_PRE_ROUTING: @@ -489,7 +474,7 @@ func printReplace(optVal []byte) { replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) - log.Infof("kevin: Replacing table %q: %+v", replace.Name.String(), replace) + log.Infof("Replacing table %q: %+v", replace.Name.String(), replace) // Read in the list of entries at the end of replace. var totalOffset uint16 @@ -497,33 +482,33 @@ func printReplace(optVal []byte) { var entry linux.IPTEntry entryBuf := optVal[:linux.SizeOfIPTEntry] binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry) - log.Infof("kevin: Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry) + log.Infof("Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry) totalOffset += entry.NextOffset if entry.TargetOffset == linux.SizeOfIPTEntry { - log.Infof("kevin: Entry has no matches.") + log.Infof("Entry has no matches.") } else { - log.Infof("kevin: Entry has matches.") + log.Infof("Entry has matches.") } var target linux.XTEntryTarget targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget] binary.Unmarshal(targetBuf, usermem.ByteOrder, &target) - log.Infof("kevin: Target named %q: %+v", target.Name.String(), target) + log.Infof("Target named %q: %+v", target.Name.String(), target) switch target.Name.String() { case "": var standardTarget linux.XTStandardTarget stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget] binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget) - log.Infof("kevin: Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict) + log.Infof("Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict) case errorTargetName: var errorTarget linux.XTErrorTarget etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget] binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget) - log.Infof("kevin: Error target with name %q.", errorTarget.Name.String()) + log.Infof("Error target with name %q.", errorTarget.Name.String()) default: - log.Infof("kevin: Unknown target type.") + log.Infof("Unknown target type.") } optVal = optVal[entry.NextOffset:] diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8c07eef4b..cd3dd1a53 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1368,10 +1368,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return syserr.ErrNoDevice } // Stack must be a netstack stack. - if err := netfilter.SetEntries(stack.(*Stack).Stack, optVal); err != nil { - return err - } - return nil + return netfilter.SetEntries(stack.(*Stack).Stack, optVal) case linux.IPT_SO_SET_ADD_COUNTERS: // TODO(gvisor.dev/issue/170): Counter support. -- cgit v1.2.3 From 8643933d6e58492cbe9d5c78124873ab40f65feb Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Thu, 9 Jan 2020 13:06:24 -0800 Subject: Change BindToDeviceOption to store NICID This makes it possible to call the sockopt from go even when the NIC has no name. PiperOrigin-RevId: 288955236 --- pkg/sentry/socket/netstack/netstack.go | 29 ++++++++-- pkg/tcpip/stack/stack.go | 8 +++ pkg/tcpip/stack/transport_demuxer_test.go | 89 +++++++++++++++---------------- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 27 ++++------ pkg/tcpip/transport/tcp/tcp_test.go | 42 +++++++-------- pkg/tcpip/transport/udp/endpoint.go | 27 ++++------ pkg/tcpip/transport/udp/udp_test.go | 31 +++++------ 8 files changed, 127 insertions(+), 128 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9e0d69046..764f11a6b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -985,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - if len(v) == 0 { + if v == 0 { return []byte{}, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument } - return append([]byte(v), 0), nil + s := t.NetworkContext() + if s == nil { + return nil, syserr.ErrNoDevice + } + nic, ok := s.Interfaces()[int32(v)] + if !ok { + // The NICID no longer indicates a valid interface, probably because that + // interface was removed. + return nil, syserr.ErrUnknownDevice + } + return append([]byte(nic.Name), 0), nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1438,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i if n == -1 { n = len(optVal) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + name := string(optVal[:n]) + if name == "" { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + } + s := t.NetworkContext() + if s == nil { + return syserr.ErrNoDevice + } + for nicID, nic := range s.Interfaces() { + if nic.Name == name { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + } + } + return syserr.ErrUnknownDevice case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e2a2edb2c..41bf9fd9b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -901,6 +901,14 @@ type NICInfo struct { Context NICContext } +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok +} + // NICInfo returns a map of NICIDs to their associated information. func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { s.mu.RLock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index df5ced887..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - opts := stack.NICOptions{Name: linkEpName} - if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -122,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, "dev0"}, - {0, "dev1"}, - {0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": {1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": {0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": {0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, "dev0"}, - {1, "dev0"}, - {1, "dev1"}, - {1, "dev1"}, - {1, "dev1"}, - {1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": {0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": {0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 1eca76c30..72b5ce179 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -552,7 +552,7 @@ type ReusePortOption int // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string +type BindToDeviceOption NICID // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 2ac1b6877..920b24975 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1279,19 +1279,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.QuickAckOption: if v == 0 { @@ -1550,12 +1545,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.QuickAckOption: diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 15745ebd4..1aa0733d0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -3794,47 +3794,41 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { + if err := s.CreateNIC(321, loopback.New()); err != nil { t.Errorf("CreateNIC failed: %v", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1a5ee6317..864dc8733 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -631,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.BroadcastOption: e.mu.Lock() @@ -767,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 149fff999..0a82bc4fa 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -513,42 +513,37 @@ func TestBindToDeviceOption(t *testing.T) { t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } -- cgit v1.2.3 From 27500d529f7fb87eef8812278fd1bbca67bcba72 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 9 Jan 2020 22:00:42 -0800 Subject: New sync package. * Rename syncutil to sync. * Add aliases to sync types. * Replace existing usage of standard library sync package. This will make it easier to swap out synchronization primitives. For example, this will allow us to use primitives from github.com/sasha-s/go-deadlock to check for lock ordering violations. Updates #1472 PiperOrigin-RevId: 289033387 --- pkg/amutex/BUILD | 1 + pkg/amutex/amutex_test.go | 3 +- pkg/atomicbitops/BUILD | 1 + pkg/atomicbitops/atomic_bitops_test.go | 3 +- pkg/compressio/BUILD | 5 +- pkg/compressio/compressio.go | 2 +- pkg/control/server/BUILD | 1 + pkg/control/server/server.go | 2 +- pkg/eventchannel/BUILD | 2 + pkg/eventchannel/event.go | 2 +- pkg/eventchannel/event_test.go | 2 +- pkg/fdchannel/BUILD | 1 + pkg/fdchannel/fdchannel_test.go | 3 +- pkg/fdnotifier/BUILD | 1 + pkg/fdnotifier/fdnotifier.go | 2 +- pkg/flipcall/BUILD | 3 +- pkg/flipcall/flipcall_example_test.go | 3 +- pkg/flipcall/flipcall_test.go | 3 +- pkg/flipcall/flipcall_unsafe.go | 10 +- pkg/gate/BUILD | 1 + pkg/gate/gate_test.go | 2 +- pkg/linewriter/BUILD | 1 + pkg/linewriter/linewriter.go | 3 +- pkg/log/BUILD | 5 +- pkg/log/log.go | 2 +- pkg/metric/BUILD | 1 + pkg/metric/metric.go | 2 +- pkg/p9/BUILD | 1 + pkg/p9/client.go | 2 +- pkg/p9/p9test/BUILD | 2 + pkg/p9/p9test/client_test.go | 2 +- pkg/p9/p9test/p9test.go | 2 +- pkg/p9/path_tree.go | 3 +- pkg/p9/pool.go | 2 +- pkg/p9/server.go | 2 +- pkg/p9/transport.go | 2 +- pkg/procid/BUILD | 2 + pkg/procid/procid_test.go | 3 +- pkg/rand/BUILD | 5 +- pkg/rand/rand_linux.go | 2 +- pkg/refs/BUILD | 2 + pkg/refs/refcounter.go | 2 +- pkg/refs/refcounter_test.go | 3 +- pkg/sentry/arch/BUILD | 1 + pkg/sentry/arch/arch_x86.go | 2 +- pkg/sentry/control/BUILD | 1 + pkg/sentry/control/pprof.go | 2 +- pkg/sentry/device/BUILD | 5 +- pkg/sentry/device/device.go | 2 +- pkg/sentry/fs/BUILD | 3 +- pkg/sentry/fs/copy_up.go | 2 +- pkg/sentry/fs/copy_up_test.go | 2 +- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/dirent_cache.go | 3 +- pkg/sentry/fs/dirent_cache_limiter.go | 3 +- pkg/sentry/fs/fdpipe/BUILD | 1 + pkg/sentry/fs/fdpipe/pipe.go | 2 +- pkg/sentry/fs/fdpipe/pipe_state.go | 2 +- pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/file_overlay.go | 2 +- pkg/sentry/fs/filesystems.go | 2 +- pkg/sentry/fs/fs.go | 3 +- pkg/sentry/fs/fsutil/BUILD | 1 + pkg/sentry/fs/fsutil/host_file_mapper.go | 2 +- pkg/sentry/fs/fsutil/host_mappable.go | 2 +- pkg/sentry/fs/fsutil/inode.go | 3 +- pkg/sentry/fs/fsutil/inode_cached.go | 2 +- pkg/sentry/fs/gofer/BUILD | 1 + pkg/sentry/fs/gofer/inode.go | 2 +- pkg/sentry/fs/gofer/session.go | 2 +- pkg/sentry/fs/host/BUILD | 1 + pkg/sentry/fs/host/inode.go | 2 +- pkg/sentry/fs/host/socket.go | 2 +- pkg/sentry/fs/host/tty.go | 3 +- pkg/sentry/fs/inode.go | 3 +- pkg/sentry/fs/inode_inotify.go | 3 +- pkg/sentry/fs/inotify.go | 2 +- pkg/sentry/fs/inotify_watch.go | 2 +- pkg/sentry/fs/lock/BUILD | 1 + pkg/sentry/fs/lock/lock.go | 2 +- pkg/sentry/fs/mounts.go | 2 +- pkg/sentry/fs/overlay.go | 5 +- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/seqfile/BUILD | 1 + pkg/sentry/fs/proc/seqfile/seqfile.go | 2 +- pkg/sentry/fs/proc/sys_net.go | 2 +- pkg/sentry/fs/ramfs/BUILD | 1 + pkg/sentry/fs/ramfs/dir.go | 2 +- pkg/sentry/fs/restore.go | 2 +- pkg/sentry/fs/tmpfs/BUILD | 1 + pkg/sentry/fs/tmpfs/inode_file.go | 2 +- pkg/sentry/fs/tty/BUILD | 1 + pkg/sentry/fs/tty/dir.go | 2 +- pkg/sentry/fs/tty/line_discipline.go | 2 +- pkg/sentry/fs/tty/queue.go | 3 +- pkg/sentry/fsimpl/ext/BUILD | 1 + pkg/sentry/fsimpl/ext/directory.go | 3 +- pkg/sentry/fsimpl/ext/filesystem.go | 2 +- pkg/sentry/fsimpl/ext/regular_file.go | 2 +- pkg/sentry/fsimpl/kernfs/BUILD | 2 + pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 2 +- pkg/sentry/fsimpl/tmpfs/BUILD | 1 + pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/kernel/BUILD | 5 +- pkg/sentry/kernel/abstract_socket_namespace.go | 2 +- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/auth/user_namespace.go | 2 +- pkg/sentry/kernel/epoll/BUILD | 1 + pkg/sentry/kernel/epoll/epoll.go | 2 +- pkg/sentry/kernel/eventfd/BUILD | 1 + pkg/sentry/kernel/eventfd/eventfd.go | 2 +- pkg/sentry/kernel/fasync/BUILD | 1 + pkg/sentry/kernel/fasync/fasync.go | 3 +- pkg/sentry/kernel/fd_table.go | 2 +- pkg/sentry/kernel/fd_table_test.go | 2 +- pkg/sentry/kernel/fs_context.go | 2 +- pkg/sentry/kernel/futex/BUILD | 8 +- pkg/sentry/kernel/futex/futex.go | 3 +- pkg/sentry/kernel/futex/futex_test.go | 2 +- pkg/sentry/kernel/kernel.go | 2 +- pkg/sentry/kernel/memevent/BUILD | 1 + pkg/sentry/kernel/memevent/memory_events.go | 2 +- pkg/sentry/kernel/pipe/BUILD | 1 + pkg/sentry/kernel/pipe/buffer.go | 2 +- pkg/sentry/kernel/pipe/node.go | 3 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/pipe/vfs.go | 3 +- pkg/sentry/kernel/semaphore/BUILD | 1 + pkg/sentry/kernel/semaphore/semaphore.go | 2 +- pkg/sentry/kernel/shm/BUILD | 1 + pkg/sentry/kernel/shm/shm.go | 2 +- pkg/sentry/kernel/signal_handlers.go | 3 +- pkg/sentry/kernel/signalfd/BUILD | 1 + pkg/sentry/kernel/signalfd/signalfd.go | 3 +- pkg/sentry/kernel/syscalls.go | 2 +- pkg/sentry/kernel/syslog.go | 3 +- pkg/sentry/kernel/task.go | 5 +- pkg/sentry/kernel/thread_group.go | 2 +- pkg/sentry/kernel/threads.go | 2 +- pkg/sentry/kernel/time/BUILD | 1 + pkg/sentry/kernel/time/time.go | 2 +- pkg/sentry/kernel/timekeeper.go | 2 +- pkg/sentry/kernel/tty.go | 2 +- pkg/sentry/kernel/uts_namespace.go | 3 +- pkg/sentry/limits/BUILD | 1 + pkg/sentry/limits/limits.go | 3 +- pkg/sentry/mm/BUILD | 2 +- pkg/sentry/mm/aio_context.go | 3 +- pkg/sentry/mm/mm.go | 8 +- pkg/sentry/pgalloc/BUILD | 1 + pkg/sentry/pgalloc/pgalloc.go | 2 +- pkg/sentry/platform/interrupt/BUILD | 1 + pkg/sentry/platform/interrupt/interrupt.go | 3 +- pkg/sentry/platform/kvm/BUILD | 1 + pkg/sentry/platform/kvm/address_space.go | 2 +- pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 2 - pkg/sentry/platform/kvm/kvm.go | 2 +- pkg/sentry/platform/kvm/machine.go | 2 +- pkg/sentry/platform/ptrace/BUILD | 1 + pkg/sentry/platform/ptrace/ptrace.go | 2 +- pkg/sentry/platform/ptrace/subprocess.go | 2 +- .../platform/ptrace/subprocess_linux_unsafe.go | 2 +- pkg/sentry/platform/ring0/defs.go | 2 +- pkg/sentry/platform/ring0/defs_amd64.go | 1 + pkg/sentry/platform/ring0/defs_arm64.go | 1 + pkg/sentry/platform/ring0/pagetables/BUILD | 5 +- pkg/sentry/platform/ring0/pagetables/pcids_x86.go | 2 +- pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/port/BUILD | 1 + pkg/sentry/socket/netlink/port/port.go | 3 +- pkg/sentry/socket/netlink/socket.go | 2 +- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/socket/rpcinet/conn/BUILD | 1 + pkg/sentry/socket/rpcinet/conn/conn.go | 2 +- pkg/sentry/socket/rpcinet/notifier/BUILD | 1 + pkg/sentry/socket/rpcinet/notifier/notifier.go | 2 +- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 3 +- pkg/sentry/socket/unix/transport/queue.go | 3 +- pkg/sentry/socket/unix/transport/unix.go | 2 +- pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/error.go | 2 +- pkg/sentry/time/BUILD | 4 +- pkg/sentry/time/calibrated_clock.go | 2 +- pkg/sentry/usage/BUILD | 1 + pkg/sentry/usage/memory.go | 2 +- pkg/sentry/vfs/BUILD | 3 +- pkg/sentry/vfs/dentry.go | 2 +- pkg/sentry/vfs/file_description_impl_util.go | 2 +- pkg/sentry/vfs/mount_test.go | 3 +- pkg/sentry/vfs/mount_unsafe.go | 4 +- pkg/sentry/vfs/pathname.go | 3 +- pkg/sentry/vfs/resolving_path.go | 2 +- pkg/sentry/vfs/vfs.go | 2 +- pkg/sentry/watchdog/BUILD | 1 + pkg/sentry/watchdog/watchdog.go | 2 +- pkg/sync/BUILD | 53 +++++++ pkg/sync/LICENSE | 27 ++++ pkg/sync/README.md | 5 + pkg/sync/aliases.go | 37 +++++ pkg/sync/atomicptr_unsafe.go | 47 +++++++ pkg/sync/atomicptrtest/BUILD | 29 ++++ pkg/sync/atomicptrtest/atomicptr_test.go | 31 +++++ pkg/sync/downgradable_rwmutex_test.go | 150 ++++++++++++++++++++ pkg/sync/downgradable_rwmutex_unsafe.go | 146 ++++++++++++++++++++ pkg/sync/memmove_unsafe.go | 28 ++++ pkg/sync/norace_unsafe.go | 35 +++++ pkg/sync/race_unsafe.go | 41 ++++++ pkg/sync/seqatomic_unsafe.go | 72 ++++++++++ pkg/sync/seqatomictest/BUILD | 33 +++++ pkg/sync/seqatomictest/seqatomic_test.go | 132 ++++++++++++++++++ pkg/sync/seqcount.go | 149 ++++++++++++++++++++ pkg/sync/seqcount_test.go | 153 +++++++++++++++++++++ pkg/sync/syncutil.go | 7 + pkg/syncutil/BUILD | 52 ------- pkg/syncutil/LICENSE | 27 ---- pkg/syncutil/README.md | 5 - pkg/syncutil/atomicptr_unsafe.go | 47 ------- pkg/syncutil/atomicptrtest/BUILD | 29 ---- pkg/syncutil/atomicptrtest/atomicptr_test.go | 31 ----- pkg/syncutil/downgradable_rwmutex_test.go | 150 -------------------- pkg/syncutil/downgradable_rwmutex_unsafe.go | 146 -------------------- pkg/syncutil/memmove_unsafe.go | 28 ---- pkg/syncutil/norace_unsafe.go | 35 ----- pkg/syncutil/race_unsafe.go | 41 ------ pkg/syncutil/seqatomic_unsafe.go | 72 ---------- pkg/syncutil/seqatomictest/BUILD | 35 ----- pkg/syncutil/seqatomictest/seqatomic_test.go | 132 ------------------ pkg/syncutil/seqcount.go | 149 -------------------- pkg/syncutil/seqcount_test.go | 153 --------------------- pkg/syncutil/syncutil.go | 7 - pkg/tcpip/BUILD | 1 + pkg/tcpip/adapters/gonet/BUILD | 1 + pkg/tcpip/adapters/gonet/gonet.go | 2 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/link/sharedmem/BUILD | 2 + pkg/tcpip/link/sharedmem/pipe/BUILD | 1 + pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 3 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/network/fragmentation/BUILD | 1 + pkg/tcpip/network/fragmentation/fragmentation.go | 2 +- pkg/tcpip/network/fragmentation/reassembler.go | 2 +- pkg/tcpip/ports/BUILD | 1 + pkg/tcpip/ports/ports.go | 2 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/linkaddrcache.go | 2 +- pkg/tcpip/stack/linkaddrcache_test.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/transport_demuxer.go | 2 +- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 3 +- pkg/tcpip/transport/packet/BUILD | 1 + pkg/tcpip/transport/packet/endpoint.go | 3 +- pkg/tcpip/transport/raw/BUILD | 1 + pkg/tcpip/transport/raw/endpoint.go | 3 +- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/endpoint_state.go | 2 +- pkg/tcpip/transport/tcp/forwarder.go | 3 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/segment_queue.go | 2 +- pkg/tcpip/transport/tcp/snd.go | 2 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 3 +- pkg/tmutex/BUILD | 1 + pkg/tmutex/tmutex_test.go | 3 +- pkg/unet/BUILD | 1 + pkg/unet/unet_test.go | 3 +- pkg/urpc/BUILD | 1 + pkg/urpc/urpc.go | 2 +- pkg/waiter/BUILD | 1 + pkg/waiter/waiter.go | 2 +- runsc/boot/BUILD | 2 + runsc/boot/compat.go | 2 +- runsc/boot/limits.go | 2 +- runsc/boot/loader.go | 2 +- runsc/boot/loader_test.go | 2 +- runsc/cmd/BUILD | 1 + runsc/cmd/create.go | 1 + runsc/cmd/gofer.go | 2 +- runsc/cmd/start.go | 1 + runsc/container/BUILD | 2 + runsc/container/console_test.go | 2 +- runsc/container/container_test.go | 2 +- runsc/container/multi_container_test.go | 2 +- runsc/container/state_file.go | 2 +- runsc/fsgofer/BUILD | 1 + runsc/fsgofer/fsgofer.go | 2 +- runsc/sandbox/BUILD | 1 + runsc/sandbox/sandbox.go | 2 +- runsc/testutil/BUILD | 1 + runsc/testutil/testutil.go | 2 +- 303 files changed, 1507 insertions(+), 1368 deletions(-) create mode 100644 pkg/sync/BUILD create mode 100644 pkg/sync/LICENSE create mode 100644 pkg/sync/README.md create mode 100644 pkg/sync/aliases.go create mode 100644 pkg/sync/atomicptr_unsafe.go create mode 100644 pkg/sync/atomicptrtest/BUILD create mode 100644 pkg/sync/atomicptrtest/atomicptr_test.go create mode 100644 pkg/sync/downgradable_rwmutex_test.go create mode 100644 pkg/sync/downgradable_rwmutex_unsafe.go create mode 100644 pkg/sync/memmove_unsafe.go create mode 100644 pkg/sync/norace_unsafe.go create mode 100644 pkg/sync/race_unsafe.go create mode 100644 pkg/sync/seqatomic_unsafe.go create mode 100644 pkg/sync/seqatomictest/BUILD create mode 100644 pkg/sync/seqatomictest/seqatomic_test.go create mode 100644 pkg/sync/seqcount.go create mode 100644 pkg/sync/seqcount_test.go create mode 100644 pkg/sync/syncutil.go delete mode 100644 pkg/syncutil/BUILD delete mode 100644 pkg/syncutil/LICENSE delete mode 100644 pkg/syncutil/README.md delete mode 100644 pkg/syncutil/atomicptr_unsafe.go delete mode 100644 pkg/syncutil/atomicptrtest/BUILD delete mode 100644 pkg/syncutil/atomicptrtest/atomicptr_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_unsafe.go delete mode 100644 pkg/syncutil/memmove_unsafe.go delete mode 100644 pkg/syncutil/norace_unsafe.go delete mode 100644 pkg/syncutil/race_unsafe.go delete mode 100644 pkg/syncutil/seqatomic_unsafe.go delete mode 100644 pkg/syncutil/seqatomictest/BUILD delete mode 100644 pkg/syncutil/seqatomictest/seqatomic_test.go delete mode 100644 pkg/syncutil/seqcount.go delete mode 100644 pkg/syncutil/seqcount_test.go delete mode 100644 pkg/syncutil/syncutil.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 6bc486b62..d99e37b40 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["amutex_test.go"], embed = [":amutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go index 1d7f45641..8a3952f2a 100644 --- a/pkg/amutex/amutex_test.go +++ b/pkg/amutex/amutex_test.go @@ -15,9 +15,10 @@ package amutex import ( - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) type sleeper struct { diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 36beaade9..6403c60c2 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -20,4 +20,5 @@ go_test( size = "small", srcs = ["atomic_bitops_test.go"], embed = [":atomicbitops"], + deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go index 965e9be79..9466d3e23 100644 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ b/pkg/atomicbitops/atomic_bitops_test.go @@ -16,8 +16,9 @@ package atomicbitops import ( "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) const iterations = 100 diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index a0b21d4bd..2bb581b18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["compressio.go"], importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], - deps = ["//pkg/binary"], + deps = [ + "//pkg/binary", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 3b0bb086e..5f52cbe74 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -52,9 +52,9 @@ import ( "hash" "io" "runtime" - "sync" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" ) var bufPool = sync.Pool{ diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index 21adf3adf..adbd1e3f8 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", ], diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index a56152d10..41abe1f2d 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -22,9 +22,9 @@ package server import ( "os" - "sync" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 0b4b7cc44..9d68682c7 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ ":eventchannel_go_proto", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes:go_default_library_gen", @@ -40,6 +41,7 @@ go_test( srcs = ["event_test.go"], embed = [":eventchannel"], deps = [ + "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index d37ad0428..9a29c58bd 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -22,13 +22,13 @@ package eventchannel import ( "encoding/binary" "fmt" - "sync" "syscall" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 3649097d6..7f41b4a27 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -16,11 +16,11 @@ package eventchannel import ( "fmt" - "sync" "testing" "time" "github.com/golang/protobuf/proto" + "gvisor.dev/gvisor/pkg/sync" ) // testEmitter is an emitter that can be used in tests. It records all events diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index 56495cbd9..b0478c672 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["fdchannel_test.go"], embed = [":fdchannel"], + deps = ["//pkg/sync"], ) diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go index 5d01dc636..7a8a63a59 100644 --- a/pkg/fdchannel/fdchannel_test.go +++ b/pkg/fdchannel/fdchannel_test.go @@ -17,10 +17,11 @@ package fdchannel import ( "io/ioutil" "os" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSendRecvFD(t *testing.T) { diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index aca2d8a82..91a202a30 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -11,6 +11,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go index f4aae1953..a6b63c982 100644 --- a/pkg/fdnotifier/fdnotifier.go +++ b/pkg/fdnotifier/fdnotifier.go @@ -22,10 +22,10 @@ package fdnotifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index e590a71ba..85bd83af1 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//pkg/syncutil", + "//pkg/sync", ], ) @@ -31,4 +31,5 @@ go_test( "flipcall_test.go", ], embed = [":flipcall"], + deps = ["//pkg/sync"], ) diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index 8d88b845d..2e28a149a 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -17,7 +17,8 @@ package flipcall import ( "bytes" "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) func Example() { diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 168a487ec..33fd55a44 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -16,9 +16,10 @@ package flipcall import ( "runtime" - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) var testPacketWindowSize = pageSize diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 27b8939fc..ac974b232 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if syncutil.RaceEnabled { - syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if syncutil.RaceEnabled { - syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index 4b9321711..f22bd070d 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -19,5 +19,6 @@ go_test( ], deps = [ ":gate", + "//pkg/sync", ], ) diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go index 5dbd8d712..850693df8 100644 --- a/pkg/gate/gate_test.go +++ b/pkg/gate/gate_test.go @@ -15,11 +15,11 @@ package gate_test import ( - "sync" "testing" "time" "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicEnter(t *testing.T) { diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index a5d980d14..bcde6d308 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["linewriter.go"], importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go index cd6e4e2ce..a1b1285d4 100644 --- a/pkg/linewriter/linewriter.go +++ b/pkg/linewriter/linewriter.go @@ -17,7 +17,8 @@ package linewriter import ( "bytes" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Writer is an io.Writer which buffers input, flushing diff --git a/pkg/log/BUILD b/pkg/log/BUILD index fc5f5779b..0df0f2849 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -16,7 +16,10 @@ go_library( visibility = [ "//visibility:public", ], - deps = ["//pkg/linewriter"], + deps = [ + "//pkg/linewriter", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/log/log.go b/pkg/log/log.go index 9387586e6..91a81b288 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -25,12 +25,12 @@ import ( stdlog "log" "os" "runtime" - "sync" "sync/atomic" "syscall" "time" "gvisor.dev/gvisor/pkg/linewriter" + "gvisor.dev/gvisor/pkg/sync" ) // Level is the log level. diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index dd6ca6d39..9145f3233 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -14,6 +14,7 @@ go_library( ":metric_go_proto", "//pkg/eventchannel", "//pkg/log", + "//pkg/sync", ], ) diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index eadde06e4..93d4f2b8c 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -18,12 +18,12 @@ package metric import ( "errors" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index f32244c69..a3e05c96d 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fdchannel", "//pkg/flipcall", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 221516c6c..4045e41fa 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -17,12 +17,12 @@ package p9 import ( "errors" "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 28707c0ca..f4edd68b2 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -70,6 +70,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/unet", "@com_github_golang_mock//gomock:go_default_library", ], @@ -83,6 +84,7 @@ go_test( deps = [ "//pkg/fd", "//pkg/p9", + "//pkg/sync", "@com_github_golang_mock//gomock:go_default_library", ], ) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e758148d..6e7bb3db2 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -22,7 +22,6 @@ import ( "os" "reflect" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" ) func TestPanic(t *testing.T) { diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 4d3271b37..dd8b01b6d 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -17,13 +17,13 @@ package p9test import ( "fmt" - "sync" "sync/atomic" "syscall" "testing" "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go index 865459411..72ef53313 100644 --- a/pkg/p9/path_tree.go +++ b/pkg/p9/path_tree.go @@ -16,7 +16,8 @@ package p9 import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // pathNode is a single node in a path traversal. diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go index 52de889e1..2b14a5ce3 100644 --- a/pkg/p9/pool.go +++ b/pkg/p9/pool.go @@ -15,7 +15,7 @@ package p9 import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // pool is a simple allocator. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 40b8fa023..fdfa83648 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -17,7 +17,6 @@ package p9 import ( "io" "runtime/debug" - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 6e8b4bbcd..9c11e28ce 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -19,11 +19,11 @@ import ( "fmt" "io" "io/ioutil" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 078f084b2..b506813f0 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -21,6 +21,7 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) go_test( @@ -31,4 +32,5 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go index 88dd0b3ae..9ec08c3d6 100644 --- a/pkg/procid/procid_test.go +++ b/pkg/procid/procid_test.go @@ -17,9 +17,10 @@ package procid import ( "os" "runtime" - "sync" "syscall" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) // runOnMain is used to send functions to run on the main (initial) thread. diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index f4f2001f3..9d5b4859b 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -10,5 +10,8 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], + deps = [ + "//pkg/sync", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go index 2b92db3e6..0bdad5fad 100644 --- a/pkg/rand/rand_linux.go +++ b/pkg/rand/rand_linux.go @@ -19,9 +19,9 @@ package rand import ( "crypto/rand" "io" - "sync" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" ) // reader implements an io.Reader that returns pseudorandom bytes. diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 7ad59dfd7..974d9af9b 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -27,6 +27,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", ], ) @@ -35,4 +36,5 @@ go_test( size = "small", srcs = ["refcounter_test.go"], embed = [":refs"], + deps = ["//pkg/sync"], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index ad69e0757..c45ba8200 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -21,10 +21,10 @@ import ( "fmt" "reflect" "runtime" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) // RefCounter is the interface to be implemented by objects that are reference diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index ffd3d3f07..1ab4a4440 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -16,8 +16,9 @@ package refs import ( "reflect" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) type testCounter struct { diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 18c73cc24..ae3e364cd 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/limits", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 9294ac773..9f41e566f 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -19,7 +19,6 @@ package arch import ( "fmt" "io" - "sync" "syscall" "gvisor.dev/gvisor/pkg/binary" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 5522cecd0..2561a6109 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/strace", "//pkg/sentry/usage", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", ], diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index e1f2fea60..151808911 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -19,10 +19,10 @@ import ( "runtime" "runtime/pprof" "runtime/trace" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 1098ed777..97fa1512c 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["device.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], + deps = [ + "//pkg/abi/linux", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index 47945d1a7..69e71e322 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -19,10 +19,10 @@ package device import ( "bytes" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Registry tracks all simple devices and related state on the system for diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index c035ffff7..7d5d72d5a 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,7 +68,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -115,6 +115,7 @@ go_test( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 9ac62c84d..734177e90 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -17,12 +17,12 @@ package fs import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 1d80bf15a..738580c5f 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -19,13 +19,13 @@ import ( "crypto/rand" "fmt" "io" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 3cb73bd78..31fc4d87b 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -18,7 +18,6 @@ import ( "fmt" "path" "sort" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 60a15a275..25514ace4 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCache is an LRU cache of Dirents. The Dirent's refCount is diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go index ebb80bd50..525ee25f9 100644 --- a/pkg/sentry/fs/dirent_cache_limiter.go +++ b/pkg/sentry/fs/dirent_cache_limiter.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCacheLimiter acts as a global limit for all dirent caches in the diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 277ee4c31..cc43de69d 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 669ffcb75..5b6cfeb0a 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -17,7 +17,6 @@ package fdpipe import ( "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index 29175fb3d..cee87f726 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -17,10 +17,10 @@ package fdpipe import ( "fmt" "io/ioutil" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // beforeSave is invoked by stateify. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index a2f966cb6..7c4586296 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -16,7 +16,6 @@ package fs import ( "math" - "sync" "sync/atomic" "time" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 225e40186..8a633b1ba 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,13 +16,13 @@ package fs import ( "io" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index b157fd228..c5b51620a 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -18,9 +18,9 @@ import ( "fmt" "sort" "strings" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags. diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index 8b2a5e6b2..26abf49e2 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -54,10 +54,9 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 9ca695a95..945b6270d 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -93,6 +93,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index b06a71cc2..837fc70b5 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -16,7 +16,6 @@ package fsutil import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 30475f340..a625f0e26 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -16,7 +16,6 @@ package fsutil import ( "math" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostMappable implements memmap.Mappable and platform.File over a diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 4e100a402..adf5ec69c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -15,13 +15,12 @@ package fsutil import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 798920d18..20a014402 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -17,7 +17,6 @@ package fsutil import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Lock order (compare the lock order model in mm/mm.go): diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 4a005c605..fd870e8e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 91263ebdc..245fe2ef1 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -16,7 +16,6 @@ package gofer import ( "errors" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 4e358a46a..edc796ce0 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -16,7 +16,6 @@ package gofer import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 23daeb528..2b581aa69 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index a6e4a09e3..873a1c52d 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -15,7 +15,6 @@ package host import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 107336a3e..c076d5bdd 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,7 +16,6 @@ package host import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 90331e3b2..753ef8cd6 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -15,8 +15,6 @@ package host import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 91e2fde2f..468043df0 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -15,8 +15,6 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index 0f2a66a79..efd3c962b 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Watches is the collection of inotify watches on an inode. diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index ba3e0233d..cc7dd1c92 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -16,7 +16,6 @@ package fs import ( "io" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 0aa0a5e9b..900cba3ca 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -15,10 +15,10 @@ package fs import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Watch represent a particular inotify watch created by inotify_add_watch. diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 8d62642e7..2c332a82a 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -44,6 +44,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 636484424..41b040818 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -52,9 +52,9 @@ package lock import ( "fmt" "math" - "sync" "syscall" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index ac0398bd9..db3dfd096 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -19,7 +19,6 @@ import ( "math" "path" "strings" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 25573e986..4cad55327 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -17,13 +17,12 @@ package fs import ( "fmt" "strings" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -199,7 +198,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` + dirCacheMu sync.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 75cbb0622..94d46ab1b 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/waiter", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index fe7067be1..38b246dff 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/proc/device", "//pkg/sentry/kernel/time", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index 5fe823000..f9af191d5 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -17,7 +17,6 @@ package seqfile import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index bd93f83fa..a37e1fa06 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,7 +17,6 @@ package proc import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 012cb3e44..3fb7b0633 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index 78e082b8e..dcbb8eb2e 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -17,7 +17,6 @@ package ramfs import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go index f10168125..64c6a6ae9 100644 --- a/pkg/sentry/fs/restore.go +++ b/pkg/sentry/fs/restore.go @@ -15,7 +15,7 @@ package fs import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // RestoreEnvironment is the restore environment for file systems. It consists diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 59ce400c2..3400b940c 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f86dfaa36..f1c87fe41 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "fmt" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 95ad98cb0..f6f60d0cf 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 2f639c823..88aa66b24 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "strconv" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 7cc0eb409..894964260 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -16,13 +16,13 @@ package tty import ( "bytes" - "sync" "unicode/utf8" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 231e4e6eb..8b5d4699a 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -15,13 +15,12 @@ package tty import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index bc90330bc..903874141 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 91802dc1e..8944171c8 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -15,8 +15,6 @@ package ext import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 616fc002a..9afb1a84c 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -17,13 +17,13 @@ package ext import ( "errors" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index aec33e00a..d11153c90 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -16,7 +16,6 @@ package ext import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 39c03ee9d..809178250 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -39,6 +39,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) @@ -56,6 +57,7 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "@com_github_google_go-cmp//cmp:go_default_library", ], diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 752e0f659..1d469a0db 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -16,7 +16,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index d69b299ae..bb12f39a2 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -53,7 +53,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -61,6 +60,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemType implements vfs.FilesystemType. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 4b6b95f5f..5c9d580e1 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "runtime" - "sync" "testing" "github.com/google/go-cmp/cmp" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index a5b285987..82f5c2f41 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index f51e247a7..f200e767d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 7be6faa5b..701826f90 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -26,7 +26,6 @@ package tmpfs import ( "fmt" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 2706927ff..ac85ba0c8 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -35,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -209,7 +209,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", @@ -241,6 +241,7 @@ go_test( "//pkg/sentry/time", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 244655b5c..920fe4329 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -15,11 +15,11 @@ package kernel import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" ) // +stateify savable diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 04c244447..1aa72fa47 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "Credentials", }, @@ -64,6 +64,7 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go index af28ccc65..9dd52c860 100644 --- a/pkg/sentry/kernel/auth/user_namespace.go +++ b/pkg/sentry/kernel/auth/user_namespace.go @@ -16,8 +16,8 @@ package auth import ( "math" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 3361e8b7d..c47f6b6fc 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 9c0a4e1b4..430311cc0 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -18,7 +18,6 @@ package epoll import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index e65b961e8..c831fbab2 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 12f0d429b..687690679 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -18,7 +18,6 @@ package eventfd import ( "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 49d81b712..6b36bc63e 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 6b0bb0324..d32c3e90a 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -16,12 +16,11 @@ package fasync import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 11f613a11..cd1501f85 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,7 +18,6 @@ import ( "bytes" "fmt" "math" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // FDFlags define flags for an individual descriptor. diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2bcb6216a..eccb7d1e7 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -16,7 +16,6 @@ package kernel import ( "runtime" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index ded27d668..2448c1d99 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -16,10 +16,10 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // FSContext contains filesystem context. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 75ec31761..50db443ce 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "bucket", }, @@ -42,6 +42,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) @@ -51,5 +52,8 @@ go_test( size = "small", srcs = ["futex_test.go"], embed = [":futex"], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 278cc8143..d1931c8f4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -18,11 +18,10 @@ package futex import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 65e5d1428..c23126ca5 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -17,13 +17,13 @@ package futex import ( "math" "runtime" - "sync" "sync/atomic" "syscall" "testing" "unsafe" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // testData implements the Target interface, and allows us to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8653d2f63..c85e97fef 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -36,7 +36,6 @@ import ( "fmt" "io" "path/filepath" - "sync" "sync/atomic" "time" @@ -67,6 +66,7 @@ import ( uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index d7a7d1169..7f36252a9 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/usage", + "//pkg/sync", ], ) diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go index b0d98e7f0..200565bb8 100644 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ b/pkg/sentry/kernel/memevent/memory_events.go @@ -17,7 +17,6 @@ package memevent import ( - "sync" "time" "gvisor.dev/gvisor/pkg/eventchannel" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" ) var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 9d34f6d4d..5eeaeff66 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 95bee2d37..1c0f34269 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -16,9 +16,9 @@ package pipe import ( "io" - "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" ) // buffer encapsulates a queueable byte buffer. diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4a19ab7ce..716f589af 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -15,12 +15,11 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 1a1b38f83..e4fd7d420 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,12 +17,12 @@ package pipe import ( "fmt" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index ef9641e6a..8394eb78b 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -17,7 +17,6 @@ package pipe import ( "io" "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 6416e0dd8..bf7461cbb 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -15,13 +15,12 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index f4c00cd86..13a961594 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index de9617e9d..18299814e 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -17,7 +17,6 @@ package semaphore import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index cd48945e6..7321b22ed 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -24,6 +24,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 19034a21e..8ddef7eb8 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -35,7 +35,6 @@ package shm import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index a16f3d57f..768fda220 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -15,10 +15,9 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sync" ) // SignalHandlers holds information about signal actions. diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 9f7e19b4d..89e4d84b1 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 4b08d7d72..28be4a939 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -16,8 +16,6 @@ package signalfd import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 2fdee0282..d2d01add4 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -16,13 +16,13 @@ package kernel import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // maxSyscallNum is the highest supported syscall number. diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 8227ecf1d..4607cde2f 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -17,7 +17,8 @@ package kernel import ( "fmt" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // syslog represents a sentry-global kernel log. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d25a7903b..978d66da8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -17,7 +17,6 @@ package kernel import ( gocontext "context" "runtime/trace" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,7 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -85,7 +84,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq syncutil.SeqCount `state:"nosave"` + goschedSeq sync.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index c0197a563..768e958d2 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -15,7 +15,6 @@ package kernel import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 8267929a6..bf2dabb6e 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -16,9 +16,9 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 31847e1df..4e4de0512 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 107394183..706de83ef 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -19,10 +19,10 @@ package time import ( "fmt" "math" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 76417342a..dc99301de 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -16,7 +16,6 @@ package kernel import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sync" ) // Timekeeper manages all of the kernel clocks. diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 048de26dc..464d2306a 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -14,7 +14,7 @@ package kernel -import "sync" +import "gvisor.dev/gvisor/pkg/sync" // TTY defines the relationship between a thread group and its controlling // terminal. diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go index 0a563e715..8ccf04bd1 100644 --- a/pkg/sentry/kernel/uts_namespace.go +++ b/pkg/sentry/kernel/uts_namespace.go @@ -15,9 +15,8 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" ) // UTSNamespace represents a UTS namespace, a holder of two system identifiers: diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 156e67bf8..9fa841e8b 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", ], ) diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go index b6c22656b..31b9e9ff6 100644 --- a/pkg/sentry/limits/limits.go +++ b/pkg/sentry/limits/limits.go @@ -16,8 +16,9 @@ package limits import ( - "sync" "syscall" + + "gvisor.dev/gvisor/pkg/sync" ) // LimitType defines a type of resource limit. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 839931f67..83e248431 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,7 +118,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/buffer", ], diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 1b746d030..4b48866ad 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -15,8 +15,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 58a5c186d..fa86ebced 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -35,8 +35,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -44,7 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryManager implements a virtual address space. @@ -82,7 +80,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu syncutil.DowngradableRWMutex `state:"nosave"` + mappingMu sync.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +123,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu syncutil.DowngradableRWMutex `state:"nosave"` + activeMu sync.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index f404107af..a9a2642c5 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -73,6 +73,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index f7f7298c4..c99e023d9 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -25,7 +25,6 @@ import ( "fmt" "math" "os" - "sync" "sync/atomic" "syscall" "time" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index b6d008dbe..85e882df9 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -10,6 +10,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index a4651f500..57be41647 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -17,7 +17,8 @@ package interrupt import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Receiver receives interrupt notifications from a Forwarder. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index f3afd98da..6a358d1d4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -55,6 +55,7 @@ go_library( "//pkg/sentry/platform/safecopy", "//pkg/sentry/time", "//pkg/sentry/usermem", + "//pkg/sync", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index ea8b9632e..a25f3c449 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -15,13 +15,13 @@ package kvm import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // dirtySet tracks vCPUs for invalidation. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index e5fac0d6a..2f02c03cf 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,8 +17,6 @@ package kvm import ( - "unsafe" - "gvisor.dev/gvisor/pkg/sentry/arch" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index f2c2c059e..a7850faed 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -18,13 +18,13 @@ package kvm import ( "fmt" "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // KVM represents a lightweight VM context. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 7d02ebf19..e6d912168 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,7 +17,6 @@ package kvm import ( "fmt" "runtime" - "sync" "sync/atomic" "syscall" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // machine contains state associated with the VM as a whole. diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 0df8cfa0f..cd13390c3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", "//pkg/sentry/usermem", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 7b120a15d..bb0e03880 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -46,13 +46,13 @@ package ptrace import ( "os" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 20244fd95..15dc46a5b 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "runtime" - "sync" "syscall" "golang.org/x/sys/unix" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Linux kernel errnos which "should never be seen by user programs", but will diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 2e6fbe488..245b20722 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,7 +18,6 @@ package ptrace import ( - "sync" "sync/atomic" "syscall" "unsafe" @@ -26,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" + "gvisor.dev/gvisor/pkg/sync" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 3f094c2a7..86fd5ed58 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -17,7 +17,7 @@ package ring0 import ( "syscall" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 10dbd381f..9dae0dccb 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index dc0eeec01..a850ce6cf 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index e2e15ba5c..387a7f6c3 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -96,7 +96,10 @@ go_library( "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", ], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go index 0f029f25d..e199bae18 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go @@ -17,7 +17,7 @@ package pagetables import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // limitPCID is the number of valid PCIDs. diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 136821963..103933144 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 463544c1a..2d9f4ba9b 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["port.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go index e9d3275b1..2cd3afc22 100644 --- a/pkg/sentry/socket/netlink/port/port.go +++ b/pkg/sentry/socket/netlink/port/port.go @@ -24,7 +24,8 @@ import ( "fmt" "math" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // maxPorts is a sanity limit on the maximum number of ports to allocate per diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d2e3644a6..cea56f4ed 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -17,7 +17,6 @@ package netlink import ( "math" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e414d8055..f78784569 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 764f11a6b..0affb8071 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,7 +29,6 @@ import ( "io" "math" "reflect" - "sync" "syscall" "time" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD index 23eadcb1b..b2677c659 100644 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ b/pkg/sentry/socket/rpcinet/conn/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", + "//pkg/sync", "//pkg/syserr", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go index 356adad99..02f39c767 100644 --- a/pkg/sentry/socket/rpcinet/conn/conn.go +++ b/pkg/sentry/socket/rpcinet/conn/conn.go @@ -17,12 +17,12 @@ package conn import ( "fmt" - "sync" "sync/atomic" "syscall" "github.com/golang/protobuf/proto" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/unet" diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD index a3585e10d..a5954f22b 100644 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ b/pkg/sentry/socket/rpcinet/notifier/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", "//pkg/sentry/socket/rpcinet/conn", + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go index 7efe4301f..82b75d6dd 100644 --- a/pkg/sentry/socket/rpcinet/notifier/notifier.go +++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go @@ -17,12 +17,12 @@ package notifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..d7ba95dff 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/ilist", "//pkg/refs", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..9e6fbc111 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..5dcd3d95e 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,9 +15,8 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37c7ac3c1..fcc0da332 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,11 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index a76975cee..aa05e208a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -91,6 +91,7 @@ go_library( "//pkg/sentry/syscalls", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 1d9018c96..60469549d 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -16,13 +16,13 @@ package linux import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 18e212dff..3cde3a0be 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,7 +36,7 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 318503277..f9a93115d 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -17,11 +17,11 @@ package time import ( - "sync" "time" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index c32fe3241..5518ac3d0 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -18,5 +18,6 @@ go_library( deps = [ "//pkg/bits", "//pkg/memutil", + "//pkg/sync", ], ) diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index d6ef644d8..538c645eb 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -17,12 +17,12 @@ package usage import ( "fmt" "os" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryKind represents a type of memory used by the application. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 4c6aa04a1..35c7be259 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -34,7 +34,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -54,6 +54,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 1bc9c4a38..486a76475 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,9 +16,9 @@ package vfs import ( "fmt" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 66eb57bc2..c00b3c84b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,13 +17,13 @@ package vfs import ( "bytes" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index adff0b94b..3b933468d 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -17,8 +17,9 @@ package vfs import ( "fmt" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestMountTableLookupEmpty(t *testing.T) { diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index ab13fa461..bd90d36c4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // mountKey represents the location at which a Mount is mounted. It is @@ -75,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq syncutil.SeqCount + seq sync.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index 8e155654f..cf80df90e 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -15,10 +15,9 @@ package vfs import ( - "sync" - "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index f0641d314..8a0b382f6 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -16,11 +16,11 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index ea2db7031..1f21b0b31 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -29,12 +29,12 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 4d8435265..28f21f13d 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -13,5 +13,6 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sync", ], ) diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 5e4611333..bfb2fac26 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -32,7 +32,6 @@ package watchdog import ( "bytes" "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -40,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" ) // Opts configures the watchdog. diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD new file mode 100644 index 000000000..e8cd16b8f --- /dev/null +++ b/pkg/sync/BUILD @@ -0,0 +1,53 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "sync", + srcs = [ + "aliases.go", + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sync", +) + +go_test( + name = "sync_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed = [":sync"], +) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/sync/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go new file mode 100644 index 000000000..20c7ca041 --- /dev/null +++ b/pkg/sync/aliases.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "sync" +) + +// Aliases of standard library types. +type ( + // Mutex is an alias of sync.Mutex. + Mutex = sync.Mutex + + // RWMutex is an alias of sync.RWMutex. + RWMutex = sync.RWMutex + + // Cond is an alias of sync.Cond. + Cond = sync.Cond + + // Locker is an alias of sync.Locker. + Locker = sync.Locker + + // Once is an alias of sync.Once. + Once = sync.Once + + // Pool is an alias of sync.Pool. + Pool = sync.Pool + + // WaitGroup is an alias of sync.WaitGroup. + WaitGroup = sync.WaitGroup + + // Map is an alias of sync.Map. + Map = sync.Map +) diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/sync/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD new file mode 100644 index 000000000..418eda29c --- /dev/null +++ b/pkg/sync/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/sync:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/sync/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go new file mode 100644 index 000000000..f04496bc5 --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package sync + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. + for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..9bb55cd3a --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. +// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package sync + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts as 1 since r + // includes this goroutine. + for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go new file mode 100644 index 000000000..ad4a3a37e --- /dev/null +++ b/pkg/sync/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go new file mode 100644 index 000000000..006055dd6 --- /dev/null +++ b/pkg/sync/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package sync + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go new file mode 100644 index 000000000..31d8fa9a6 --- /dev/null +++ b/pkg/sync/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package sync + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go new file mode 100644 index 000000000..eda6fb131 --- /dev/null +++ b/pkg/sync/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if sync.RaceEnabled { + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD new file mode 100644 index 000000000..eba21518d --- /dev/null +++ b/pkg/sync/seqatomictest/BUILD @@ -0,0 +1,33 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/sync:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", + deps = [ + "//pkg/sync", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..2c4568b07 --- /dev/null +++ b/pkg/sync/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go new file mode 100644 index 000000000..a1e895352 --- /dev/null +++ b/pkg/sync/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. +// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. +func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInTypeOf returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go new file mode 100644 index 000000000..6eb7b4b59 --- /dev/null +++ b/pkg/sync/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/sync/syncutil.go b/pkg/sync/syncutil.go new file mode 100644 index 000000000..b16cf5333 --- /dev/null +++ b/pkg/sync/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sync provides synchronization primitives. +package sync diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD deleted file mode 100644 index cb1f41628..000000000 --- a/pkg/syncutil/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_template( - name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - ], -) - -go_library( - name = "syncutil", - srcs = [ - "downgradable_rwmutex_unsafe.go", - "memmove_unsafe.go", - "norace_unsafe.go", - "race_unsafe.go", - "seqcount.go", - "syncutil.go", - ], - importpath = "gvisor.dev/gvisor/pkg/syncutil", -) - -go_test( - name = "syncutil_test", - size = "small", - srcs = [ - "downgradable_rwmutex_test.go", - "seqcount_test.go", - ], - embed = [":syncutil"], -) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/syncutil/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md deleted file mode 100644 index 2183c4e20..000000000 --- a/pkg/syncutil/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Syncutil - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go deleted file mode 100644 index 525c4beed..000000000 --- a/pkg/syncutil/atomicptr_unsafe.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "sync/atomic" - "unsafe" -) - -// Value is a required type parameter. -type Value struct{} - -// An AtomicPtr is a pointer to a value of type Value that can be atomically -// loaded and stored. The zero value of an AtomicPtr represents nil. -// -// Note that copying AtomicPtr by value performs a non-atomic read of the -// stored pointer, which is unsafe if Store() can be called concurrently; in -// this case, do `dst.Store(src.Load())` instead. -// -// +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` -} - -func (p *AtomicPtr) savePtr() *Value { - return p.Load() -} - -func (p *AtomicPtr) loadPtr(v *Value) { - p.Store(v) -} - -// Load returns the value set by the most recent Store. It returns nil if there -// has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) -} - -// Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { - atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) -} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD deleted file mode 100644 index 63f411a90..000000000 --- a/pkg/syncutil/atomicptrtest/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = "//pkg/syncutil:generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], -) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/syncutil/atomicptrtest/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go deleted file mode 100644 index ffaf7ecc7..000000000 --- a/pkg/syncutil/downgradable_rwmutex_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package syncutil - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m DowngradableRWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - n = atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm DowngradableRWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go deleted file mode 100644 index 51e11555d..000000000 --- a/pkg/syncutil/downgradable_rwmutex_unsafe.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -// This is mostly copied from the standard library's sync/rwmutex.go. -// -// Happens-before relationships indicated to the race detector: -// - Unlock -> Lock (via writerSem) -// - Unlock -> RLock (via readerSem) -// - RUnlock -> Lock (via writerSem) -// - DowngradeLock -> RLock (via readerSem) - -package syncutil - -import ( - "sync" - "sync/atomic" - "unsafe" -) - -//go:linkname runtimeSemacquire sync.runtime_Semacquire -func runtimeSemacquire(s *uint32) - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) - -// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock -// method. -type DowngradableRWMutex struct { - w sync.Mutex // held if there are pending writers - writerSem uint32 // semaphore for writers to wait for completing readers - readerSem uint32 // semaphore for readers to wait for completing writers - readerCount int32 // number of pending readers - readerWait int32 // number of departing readers -} - -const rwmutexMaxReaders = 1 << 30 - -// RLock locks rw for reading. -func (rw *DowngradableRWMutex) RLock() { - if RaceEnabled { - RaceDisable() - } - if atomic.AddInt32(&rw.readerCount, 1) < 0 { - // A writer is pending, wait for it. - runtimeSemacquire(&rw.readerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.readerSem)) - } -} - -// RUnlock undoes a single RLock call. -func (rw *DowngradableRWMutex) RUnlock() { - if RaceEnabled { - RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) - RaceDisable() - } - if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { - if r+1 == 0 || r+1 == -rwmutexMaxReaders { - panic("RUnlock of unlocked DowngradableRWMutex") - } - // A writer is pending. - if atomic.AddInt32(&rw.readerWait, -1) == 0 { - // The last reader unblocks the writer. - runtimeSemrelease(&rw.writerSem, false, 0) - } - } - if RaceEnabled { - RaceEnable() - } -} - -// Lock locks rw for writing. -func (rw *DowngradableRWMutex) Lock() { - if RaceEnabled { - RaceDisable() - } - // First, resolve competition with other writers. - rw.w.Lock() - // Announce to readers there is a pending writer. - r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders - // Wait for active readers. - if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { - runtimeSemacquire(&rw.writerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.writerSem)) - } -} - -// Unlock unlocks rw for writing. -func (rw *DowngradableRWMutex) Unlock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.writerSem)) - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) - if r >= rwmutexMaxReaders { - panic("Unlock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. - for i := 0; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} - -// DowngradeLock atomically unlocks rw for writing and locks it for reading. -func (rw *DowngradableRWMutex) DowngradeLock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer and one additional reader. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) - if r >= rwmutexMaxReaders+1 { - panic("DowngradeLock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. Note that this loop starts as 1 since r - // includes this goroutine. - for i := 1; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed to rw.w.Lock(). Note that they will still - // block on rw.writerSem since at least this reader exists, such that - // DowngradeLock() is atomic with the previous write lock. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go deleted file mode 100644 index 348675baa..000000000 --- a/pkg/syncutil/memmove_unsafe.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncutil - -import ( - "unsafe" -) - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - -// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't -// define it because go_generics can't update the go:linkname annotation. -// Furthermore, go:linkname silently doesn't work if the local name is exported -// (this is of course undocumented), which is why this indirection is -// necessary. -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go deleted file mode 100644 index 0a0a9deda..000000000 --- a/pkg/syncutil/norace_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !race - -package syncutil - -import ( - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = false - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { -} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go deleted file mode 100644 index 206067ec1..000000000 --- a/pkg/syncutil/race_unsafe.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build race - -package syncutil - -import ( - "runtime" - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = true - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { - runtime.RaceDisable() -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { - runtime.RaceEnable() -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { - runtime.RaceAcquire(addr) -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { - runtime.RaceRelease(addr) -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { - runtime.RaceReleaseMerge(addr) -} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go deleted file mode 100644 index cb6d2eb22..000000000 --- a/pkg/syncutil/seqatomic_unsafe.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "fmt" - "reflect" - "strings" - "unsafe" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -// Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. -type Value struct{} - -// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value - for { - epoch := sc.BeginRead() - if syncutil.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break - } - } - return val -} - -// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns -// (unspecified, false). -func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value - if syncutil.RaceEnabled { - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - val = *ptr - } - return val, sc.ReadOk(epoch) -} - -func init() { - var val Value - typ := reflect.TypeOf(val) - name := typ.Name() - if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { - panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) - } -} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD deleted file mode 100644 index ba18f3238..000000000 --- a/pkg/syncutil/seqatomictest/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = "//pkg/syncutil:generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", - deps = [ - "//pkg/syncutil", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], - deps = [ - "//pkg/syncutil", - ], -) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go deleted file mode 100644 index b0db44999..000000000 --- a/pkg/syncutil/seqatomictest/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go deleted file mode 100644 index 11d8dbfaa..000000000 --- a/pkg/syncutil/seqcount.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "fmt" - "reflect" - "runtime" - "sync/atomic" -) - -// SeqCount is a synchronization primitive for optimistic reader/writer -// synchronization in cases where readers can work with stale data and -// therefore do not need to block writers. -// -// Compared to sync/atomic.Value: -// -// - Mutation of SeqCount-protected data does not require memory allocation, -// whereas atomic.Value generally does. This is a significant advantage when -// writes are common. -// -// - Atomic reads of SeqCount-protected data require copying. This is a -// disadvantage when atomic reads are common. -// -// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other -// operations to be made atomic with reads of SeqCount-protected data. -// -// - SeqCount may be less flexible: as of this writing, SeqCount-protected data -// cannot include pointers. -// -// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected -// data require instantiating function templates using go_generics (see -// seqatomic.go). -type SeqCount struct { - // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd - // if a writer critical section is active, and a read from data protected - // by this SeqCount is atomic iff epoch is the same even value before and - // after the read. - epoch uint32 -} - -// SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} - -// We assume that: -// -// - All functions in sync/atomic that perform a memory read are at least a -// read fence: memory reads before calls to such functions cannot be reordered -// after the call, and memory reads after calls to such functions cannot be -// reordered before the call, even if those reads do not use sync/atomic. -// -// - All functions in sync/atomic that perform a memory write are at least a -// write fence: memory writes before calls to such functions cannot be -// reordered after the call, and memory writes after calls to such functions -// cannot be reordered before the call, even if those writes do not use -// sync/atomic. -// -// As of this writing, the Go memory model completely fails to describe -// sync/atomic, but these properties are implied by -// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. - -// BeginRead indicates the beginning of a reader critical section. Reader -// critical sections DO NOT BLOCK writer critical sections, so operations in a -// reader critical section MAY RACE with writer critical sections. Races are -// detected by ReadOk at the end of the reader critical section. Thus, the -// low-level structure of readers is generally: -// -// for { -// epoch := seq.BeginRead() -// // do something idempotent with seq-protected data -// if seq.ReadOk(epoch) { -// break -// } -// } -// -// However, since reader critical sections may race with writer critical -// sections, the Go race detector will (accurately) flag data races in readers -// using this pattern. Most users of SeqCount will need to use the -// SeqAtomicLoad function template in seqatomic.go. -func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) - } - return SeqCountEpoch{epoch} -} - -// ReadOk returns true if the reader critical section initiated by a previous -// call to BeginRead() that returned epoch did not race with any writer critical -// sections. -// -// ReadOk may be called any number of times during a reader critical section. -// Reader critical sections do not need to be explicitly terminated; the last -// call to ReadOk is implicitly the end of the reader critical section. -func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val -} - -// BeginWrite indicates the beginning of a writer critical section. -// -// SeqCount does not support concurrent writer critical sections; clients with -// concurrent writers must synchronize them using e.g. sync.Mutex. -func (s *SeqCount) BeginWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { - panic("SeqCount.BeginWrite during writer critical section") - } -} - -// EndWrite ends the effect of a preceding BeginWrite. -func (s *SeqCount) EndWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { - panic("SeqCount.EndWrite outside writer critical section") - } -} - -// PointersInType returns a list of pointers reachable from values named -// valName of the given type. -// -// PointersInType is not exhaustive, but it is guaranteed that if typ contains -// at least one pointer, then PointersInTypeOf returns a non-empty list. -func PointersInType(typ reflect.Type, valName string) []string { - switch kind := typ.Kind(); kind { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return nil - - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: - return []string{valName} - - case reflect.Array: - return PointersInType(typ.Elem(), valName+"[]") - - case reflect.Struct: - var ptrs []string - for i, n := 0, typ.NumField(); i < n; i++ { - field := typ.Field(i) - ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) - } - return ptrs - - default: - return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} - } -} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go deleted file mode 100644 index 14d6aedea..000000000 --- a/pkg/syncutil/seqcount_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "reflect" - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go deleted file mode 100644 index 66e750d06..000000000 --- a/pkg/syncutil/syncutil.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package syncutil provides synchronization primitives. -package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e07ebd153..db06d02c6 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,6 +15,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/tcpip/iptables", "//pkg/waiter", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..3df7d18d3 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..a2f44b496 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 897c94821..66cc53ed4 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -16,6 +16,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index fa8a703d9..b7f60178e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,10 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index a4f9cdd69..09165dd4c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,6 +32,7 @@ go_test( ], embed = [":sharedmem"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 6b5bc542c..a0d4ad0be 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -21,4 +21,5 @@ go_test( "pipe_test.go", ], embed = [":pipe"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 080f9d667..655e537c4 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 89603c48f..5c729a439 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index acf1e022c..ed16076fd 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..92f2aa13a 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -19,9 +19,9 @@ package fragmentation import ( "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..0a83d81f2 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index e156b01f6..a6ef3bdcc 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 6c5e19e8f..b937cb84b 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 826fca4de..6a8654105 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -80,6 +81,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3810c6602..fe557ccbd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,9 +16,9 @@ package stack import ( "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41bf9fd9b..a47ceba54 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..f384a91de 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 72b5ce179..4a090ac86 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,10 +35,10 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d8c5b5058..3aa23d529 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c7ce74cdd..330786f4c 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,8 +15,7 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 44b58ff6b..4858d150c 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 07ffa8aba..fc5bc69fa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,8 +25,7 @@ package packet import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 00991ac8e..2f2131ff7 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 85f7eb76b..ee9c4c58b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,8 +26,7 @@ package raw import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3b353d56c..353bd06f4 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5422ae80c..1ea996936 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,11 +19,11 @@ import ( "encoding/binary" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cdd69f360..613ec1775 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,11 +16,11 @@ package tcp import ( "encoding/binary" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 830bc1e3e..cca511fb9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -19,12 +19,12 @@ import ( "fmt" "math" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7aa4c3f0e..4b8d867bc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,9 @@ package tcp import ( "fmt" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 4983bca81..7eb613be5 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,8 +15,7 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index bc718064c..9a8f64aa6 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -22,9 +22,9 @@ package tcp import ( "strings" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..bd20a7ee9 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8a947dc66..79f2d274b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,11 +16,11 @@ package tcp import ( "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 97e4d5825..57ff123e3 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 864dc8733..a4ff29a7d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,7 @@ package udp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 6afdb29b7..07778e4f7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "medium", srcs = ["tmutex_test.go"], embed = [":tmutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go index ce34c7962..05540696a 100644 --- a/pkg/tmutex/tmutex_test.go +++ b/pkg/tmutex/tmutex_test.go @@ -17,10 +17,11 @@ package tmutex import ( "fmt" "runtime" - "sync" "sync/atomic" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicLock(t *testing.T) { diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 8f6f180e5..d1885ae66 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -24,4 +24,5 @@ go_test( "unet_test.go", ], embed = [":unet"], + deps = ["//pkg/sync"], ) diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index a3cc6f5d3..5c4b9e8e9 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -19,10 +19,11 @@ import ( "os" "path/filepath" "reflect" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func randomFilename() (string, error) { diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b6bbb0ea2..b8fdc3125 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/fd", "//pkg/log", + "//pkg/sync", "//pkg/unet", ], ) diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index df59ffab1..13b2ea314 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -27,10 +27,10 @@ import ( "os" "reflect" "runtime" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 0427bc41f..1c6890e52 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -24,6 +24,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 8a65ed164..f708e95fa 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -58,7 +58,7 @@ package waiter import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // EventMask represents io events as used in the poll() syscall. diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 6226b63f8..3e20f8f2f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/link/fdbased", @@ -114,6 +115,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 352e710d2..9c23b9553 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -17,7 +17,6 @@ package boot import ( "fmt" "os" - "sync" "syscall" "github.com/golang/protobuf/proto" @@ -27,6 +26,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) func initCompatLogs(fd int) error { diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go index d1c0bb9b5..ce62236e5 100644 --- a/runsc/boot/limits.go +++ b/runsc/boot/limits.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "sync" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // Mapping from linux resource names to limits.LimitType. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index bc1d0c1bb..fad72f4ab 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -20,7 +20,6 @@ import ( mrand "math/rand" "os" "runtime" - "sync" "sync/atomic" "syscall" gtime "time" @@ -46,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 147ff7703..bec0dc292 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -19,7 +19,6 @@ import ( "math/rand" "os" "reflect" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/fsgofer" ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 250845ad7..b94bc4fa0 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index a4e3071b3..1815c93b9 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 4831210c0..7df7995f0 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "sync" "syscall" "flag" @@ -30,6 +29,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/fsgofer" diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index de2115dff..5e9bc53ab 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 2bd12120d..6dea179e4 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sentry/control", + "//pkg/sync", "//runsc/boot", "//runsc/cgroup", "//runsc/sandbox", @@ -53,6 +54,7 @@ go_test( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 5ed131a7f..060b63bf3 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -20,7 +20,6 @@ import ( "io" "os" "path/filepath" - "sync" "syscall" "testing" "time" @@ -29,6 +28,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c10f85992..b54d8f712 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -26,7 +26,6 @@ import ( "reflect" "strconv" "strings" - "sync" "syscall" "testing" "time" @@ -39,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/specutils" diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 4ad09ceab..2da93ec5b 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -22,7 +22,6 @@ import ( "path" "path/filepath" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index d95151ea5..17a251530 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -20,10 +20,10 @@ import ( "io/ioutil" "os" "path/filepath" - "sync" "github.com/gofrs/flock" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) const stateFileExtension = ".state" diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index afcb41801..a9582d92b 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/syserr", "//runsc/specutils", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index b59e1a70e..93606d051 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "golang.org/x/sys/unix" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 8001949d5..ddbc37456 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/sync", "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/urpc", diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ce1452b87..ec72bdbfd 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -22,7 +22,6 @@ import ( "os" "os/exec" "strconv" - "sync" "syscall" "time" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index c96ca2eb6..3c3027cb5 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 9632776d2..fb22eae39 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -34,7 +34,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) -- cgit v1.2.3 From debd213da61cf35d7c91346820e93fc87bfa5896 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 13 Jan 2020 14:45:31 -0800 Subject: Allow dual stack sockets to operate on AF_INET Fixes #1490 Fixes #1495 PiperOrigin-RevId: 289523250 --- pkg/sentry/socket/netstack/netstack.go | 65 +++++++++--- pkg/sentry/socket/unix/unix.go | 5 +- pkg/sentry/strace/socket.go | 2 +- pkg/tcpip/stack/stack.go | 43 ++++++++ pkg/tcpip/transport/icmp/endpoint.go | 22 ++-- pkg/tcpip/transport/tcp/endpoint.go | 23 +--- pkg/tcpip/transport/udp/endpoint.go | 41 ++------ scripts/common.sh | 2 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_inet_loopback.cc | 156 ++++++++++++++++++++++++++++ 10 files changed, 278 insertions(+), 82 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 099319327..c020c11cb 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -324,22 +324,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // converts it to the FullAddress format. It supports AF_UNIX, AF_INET, // AF_INET6, and AF_PACKET addresses. // -// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. -// // AddressAndFamily returns an address and its family. -func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } - family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported - } - // Get the rest of the fields based on the address family. - switch family { + switch family := usermem.ByteOrder.Uint16(addr); family { case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { @@ -638,10 +631,40 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { return r } +func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error { + if family == uint16(s.family) { + return nil + } + if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { + v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { + return syserr.TranslateNetstackError(err) + } + if !v { + return nil + } + } + return syserr.ErrInvalidArgument +} + +// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the +// receiver's family is AF_INET6. +// +// This is a hack to work around the fact that both IPv4 and IPv6 ANY are +// represented by the empty string. +// +// TODO(gvisor.dev/issues/1556): remove this function. +func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { + if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { + addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + } + return addr +} + // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } @@ -653,6 +676,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } return syserr.TranslateNetstackError(err) } + + if err := s.checkFamily(family, false /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) + // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -681,10 +710,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } + if err := s.checkFamily(family, true /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) // Issue the bind request to the endpoint. return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) @@ -2080,8 +2113,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) case linux.AF_INET6: var out linux.SockAddrInet6 - if len(addr.Addr) == 4 { - // Copy address is v4-mapped format. + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. copy(out.Addr[12:], addr.Addr) out.Addr[10] = 0xff out.Addr[11] = 0xff @@ -2395,10 +2428,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) + addrBuf, family, err := AddressAndFamily(to) if err != nil { return 0, err } + if err := s.checkFamily(family, false /* exact */); err != nil { + return 0, err + } + addrBuf = s.mapFamily(addrBuf, family) addr = &addrBuf } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 91effe89a..7f49ba864 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -116,13 +116,16 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, family, err := netstack.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument } return "", err } + if family != linux.AF_UNIX { + return "", syserr.ErrInvalidArgument + } // The address is trimmed by GetAddress. p := string(addr.Addr) diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 51f2efb39..b6d7177f4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */) + fa, _, err := netstack.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a47ceba54..113b457fb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -547,6 +547,49 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } +// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address +// and returns the network protocol number to be used to communicate with the +// specified address. It returns an error if the passed address is incompatible +// with the receiver. +func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto + switch len(addr.Addr) { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + if header.IsV4MappedAddress(addr.Addr) { + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == header.IPv4Any { + addr.Addr = "" + } + } + } + + switch len(e.ID.LocalAddress) { + case header.IPv4AddressSize: + if len(addr.Addr) == header.IPv6AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + case header.IPv6AddressSize: + if len(addr.Addr) == header.IPv4AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + } + } + + switch { + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + if v6only { + return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + } + default: + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + + return addr, netProto, nil +} + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo // marker interface. func (*TransportEndpointInfo) IsEndpointInfo() {} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 330786f4c..42afb3f5b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -288,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c toCopy := *to to = &toCopy - netProto, err := e.checkV4Mapped(to, true) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -475,18 +475,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - return 0, tcpip.ErrNoRoute - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -518,7 +512,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -631,7 +625,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cca511fb9..cc8b533c8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1691,26 +1691,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a4ff29a7d..13446f5d9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -402,7 +402,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to, false) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -501,7 +501,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa, false) + netProto, err := e.checkV4Mapped(&fa) if err != nil { return err } @@ -839,35 +839,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if len(addr.Addr) == 0 { - return netProto, nil - } - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - - // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.ID.LocalAddress) == 16 { - return 0, tcpip.ErrNetworkUnreachable - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -916,7 +893,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -1074,7 +1051,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, true) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } diff --git a/scripts/common.sh b/scripts/common.sh index 6dabad141..fdb1aa142 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -73,7 +73,7 @@ function install_runsc() { sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@" # Clear old logs files that may exist. - sudo rm -f "${RUNSC_LOGS_DIR}"/* + sudo rm -f "${RUNSC_LOGS_DIR}"/'*' # Restart docker to pick up the new runtime configuration. sudo systemctl restart docker diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ce8abe217..4c7ec3f06 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2693,6 +2693,7 @@ cc_binary( srcs = ["socket_inet_loopback.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", ":socket_test_util", "//test/util:file_descriptor", "//test/util:posix_error", diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 138024d9e..5d114d460 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -32,6 +32,7 @@ #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -102,6 +103,161 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { SyscallFailsWithErrno(EAFNOSUPPORT)); } +enum class Operation { + Bind, + Connect, + SendTo, +}; + +std::string OperationToString(Operation operation) { + switch (operation) { + case Operation::Bind: + return "Bind"; + case Operation::Connect: + return "Connect"; + case Operation::SendTo: + return "SendTo"; + } +} + +using OperationSequence = std::vector; + +using DualStackSocketTest = + ::testing::TestWithParam>; + +TEST_P(DualStackSocketTest, AddressOperations) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0)); + + const TestAddress& addr = std::get<0>(GetParam()); + const OperationSequence& operations = std::get<1>(GetParam()); + + auto addr_in = reinterpret_cast(&addr.addr); + + // sockets may only be bound once. Both `connect` and `sendto` cause a socket + // to be bound. + bool bound = false; + for (const Operation& operation : operations) { + bool sockname = false; + bool peername = false; + switch (operation) { + case Operation::Bind: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 0)); + + int bind_ret = bind(fd.get(), addr_in, addr.addr_len); + + // Dual stack sockets may only be bound to AF_INET6. + if (!bound && addr.family() == AF_INET6) { + EXPECT_THAT(bind_ret, SyscallSucceeds()); + bound = true; + + sockname = true; + } else { + EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL)); + } + break; + } + case Operation::Connect: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 1337)); + + EXPECT_THAT(connect(fd.get(), addr_in, addr.addr_len), + SyscallSucceeds()) + << GetAddrStr(addr_in); + bound = true; + + sockname = true; + peername = true; + + break; + } + case Operation::SendTo: { + const char payload[] = "hello"; + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 1337)); + + ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0, + addr_in, addr.addr_len); + + EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload))); + sockname = !bound; + bound = true; + break; + } + } + + if (sockname) { + sockaddr_storage sock_addr; + socklen_t addrlen = sizeof(sock_addr); + ASSERT_THAT(getsockname(fd.get(), reinterpret_cast(&sock_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + auto sock_addr_in6 = reinterpret_cast(&sock_addr); + + if (operation == Operation::SendTo) { + EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6); + EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast(&sock_addr)); + + EXPECT_NE(sock_addr_in6->sin6_port, 0); + } else if (IN6_IS_ADDR_V4MAPPED( + reinterpret_cast(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast(&sock_addr)); + } + } + + if (peername) { + sockaddr_storage peer_addr; + socklen_t addrlen = sizeof(peer_addr); + ASSERT_THAT(getpeername(fd.get(), reinterpret_cast(&peer_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + if (addr.family() == AF_INET || + IN6_IS_ADDR_V4MAPPED(reinterpret_cast(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED( + reinterpret_cast(&peer_addr) + ->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getpeername=" + << GetAddrStr(reinterpret_cast(&peer_addr)); + } + } + } +} + +// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. +INSTANTIATE_TEST_SUITE_P( + All, DualStackSocketTest, + ::testing::Combine( + ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/ + V4MappedLoopback(), V6Any(), V6Loopback()), + ::testing::ValuesIn( + {{Operation::Bind, Operation::Connect, Operation::SendTo}, + {Operation::Bind, Operation::SendTo, Operation::Connect}, + {Operation::Connect, Operation::Bind, Operation::SendTo}, + {Operation::Connect, Operation::SendTo, Operation::Bind}, + {Operation::SendTo, Operation::Bind, Operation::Connect}, + {Operation::SendTo, Operation::Connect, Operation::Bind}})), + [](::testing::TestParamInfo< + std::tuple> const& info) { + const TestAddress& addr = std::get<0>(info.param); + const OperationSequence& operations = std::get<1>(info.param); + std::string s = addr.description; + for (const Operation& operation : operations) { + absl::StrAppend(&s, OperationToString(operation)); + } + return s; + }); + void tcpSimpleConnectTest(TestAddress const& listener, TestAddress const& connector, bool unbound) { // Create the listening socket. -- cgit v1.2.3 From 50625cee59aaff834c7968771ab385ad0e7b0e1f Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 14 Jan 2020 13:31:52 -0800 Subject: Implement {g,s}etsockopt(IP_RECVTOS) for UDP sockets PiperOrigin-RevId: 289718534 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 36 +++++++++++++-- pkg/tcpip/checker/checker.go | 16 +++++++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 8 +++- pkg/tcpip/transport/udp/endpoint.go | 40 +++++++++++++++-- pkg/tcpip/transport/udp/udp_test.go | 67 ++++++++++++++++++++++++---- test/syscalls/linux/socket_ip_udp_generic.cc | 40 +++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 10 files changed, 197 insertions(+), 24 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4301b697c..1684dfc24 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c020c11cb..d2f7e987d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1268,11 +1268,11 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o uint32 + var o int32 if v { o = 1 } - return int32(o), nil + return o, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1377,6 +1377,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIP(t, name) } @@ -1895,6 +1910,13 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1915,7 +1937,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2335,7 +2356,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe } func (s *SocketOperations) controlMessages() socket.ControlMessages { - return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}} + return socket.ControlMessages{ + IP: tcpip.ControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + }, + } } // updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..542abc99d 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + } + if got := cm.TOS; got != want { + t.Fatalf("got cm.TOS = %d, want %d", got, want) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index abf73fe33..071221d5a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -763,7 +763,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f8d89248e..386eb6eec 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -912,7 +912,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4a090ac86..b7813cbc0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS int8 + TOS uint8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -500,9 +500,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS + // ancillary message is passed with incoming packets. + ReceiveTOSOption SockOptBool = iota + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption SockOptBool = iota + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 13446f5d9..c9cbed8f4 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,6 +31,7 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -113,6 +114,10 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -243,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + cm := tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + } + e.mu.RLock() + receiveTOS := e.receiveTOS + e.mu.RUnlock() + if receiveTOS { + cm.HasTOS = true + cm.TOS = p.tos + } + return p.data.ToView(), cm, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -458,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -664,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.RLock() + v := e.receiveTOS + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.mu.RLock() v := e.v6only - e.mu.Unlock() + e.mu.RUnlock() return v, nil } @@ -1215,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() + // Save any useful information from the network header to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + } + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0a82bc4fa..ee9d10555 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,6 +56,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -552,8 +554,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -568,12 +570,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -606,15 +608,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1282,7 +1290,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1317,7 +1325,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1344,6 +1352,47 @@ func TestTOSV6(t *testing.T) { } } +func TestReceiveTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) + } + + want := true + if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { + c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) + } + + got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + if got != want { + c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) + } + + // Verify that the correct received TOS is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + testRead(c, flow, checker.ReceiveTOS(testTOS)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 66eb68857..53290bed7 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } +// Ensure that Receiving TOS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS works as expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index dc35c2f50..68e0a8109 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && + !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 47d85257d3d015f0b9f7739c81af0ee9f510aaf5 Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Fri, 17 Jan 2020 18:24:39 -0800 Subject: Filter out received packets with a local source IP address. CERT Advisory CA-96.21 III. Solution advises that devices drop packets which could not have correctly arrived on the wire, such as receiving a packet where the source IP address is owned by the device that sent it. Fixes #1507 PiperOrigin-RevId: 290378240 --- pkg/sentry/socket/netstack/netstack.go | 15 +++++----- pkg/sentry/socket/netstack/stack.go | 38 ++++++++++++------------- pkg/tcpip/stack/nic.go | 14 +++++++-- pkg/tcpip/tcpip.go | 10 +++++-- pkg/tcpip/transport/udp/udp_test.go | 52 ++++++++++++++++++++++++++++++++-- 5 files changed, 95 insertions(+), 34 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d2f7e987d..fec575357 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -138,13 +138,14 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), - MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), - MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index a0db2d4fd..31ea66eca 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -148,25 +148,25 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { case *inet.StatSNMPIP: ip := Metrics.IP *stats = inet.StatSNMPIP{ - 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. - 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. - ip.PacketsReceived.Value(), // InReceives. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. - ip.InvalidAddressesReceived.Value(), // InAddrErrors. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. - ip.PacketsDelivered.Value(), // InDelivers. - ip.PacketsSent.Value(), // OutRequests. - ip.OutgoingPacketErrors.Value(), // OutDiscards. - 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. } case *inet.StatSNMPICMP: in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 53abf29e5..4afe7b744 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -984,7 +984,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when -// the NIC receives a packet from the physical interface. +// the NIC receives a packet from the link endpoint. // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. @@ -1029,6 +1029,14 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link src, dst := netProto.ParseAddresses(pkt.Data.First()) + if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return @@ -1041,7 +1049,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link if n.stack.Forwarding() { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } defer r.Release() @@ -1079,7 +1087,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // If a packet socket handled the packet, don't treat it as invalid. if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b7813cbc0..6243762e3 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -903,9 +903,13 @@ type IPStats struct { // link layer in nic.DeliverNetworkPacket. PacketsReceived *StatCounter - // InvalidAddressesReceived is the total number of IP packets received - // with an unknown or invalid destination address. - InvalidAddressesReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived *StatCounter + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived *StatCounter // PacketsDelivered is the total number of incoming IP packets that // are successfully delivered to the transport layer via HandlePacket. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index ee9d10555..51bb61167 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -274,11 +274,16 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - - s := stack.New(stack.Options{ + return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, }) +} + +func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { + t.Helper() + + s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -763,6 +768,49 @@ func TestV6ReadOnV6(t *testing.T) { testRead(c, unicastV6) } +// TestV4ReadSelfSource checks that packets coming from a local IP address are +// correctly dropped when handleLocal is true and not otherwise. +func TestV4ReadSelfSource(t *testing.T) { + for _, tt := range []struct { + name string + handleLocal bool + wantErr *tcpip.Error + wantInvalidSource uint64 + }{ + {"HandleLocal", false, nil, 0}, + {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + } { + t.Run(tt.name, func(t *testing.T) { + c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + HandleLocal: tt.handleLocal, + }) + defer c.cleanup() + + c.createEndpointForFlow(unicastV4) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + h.srcAddr = h.dstAddr + + c.injectV4Packet(payload, &h, true /* valid */) + + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { + t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) + } + + if _, _, err := c.ep.Read(nil); err != tt.wantErr { + t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) + } + }) + } +} + func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() -- cgit v1.2.3 From 7e6fbc6afe797752efe066a8aa86f9eca973f3a4 Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Tue, 21 Jan 2020 14:47:04 -0800 Subject: Add a new TCP stat for current open connections. Such a stat accounts for all connections that are currently established and not yet transitioned to close state. Also fix bug in double increment of CurrentEstablished stat. Fixes #1579 PiperOrigin-RevId: 290827365 --- pkg/sentry/socket/netstack/netstack.go | 3 +- pkg/tcpip/tcpip.go | 6 ++- pkg/tcpip/transport/tcp/accept.go | 1 - pkg/tcpip/transport/tcp/connect.go | 1 + pkg/tcpip/transport/tcp/endpoint.go | 1 + pkg/tcpip/transport/tcp/tcp_test.go | 83 ++++++++++++++++++++++++++++++++++ 6 files changed, 92 insertions(+), 3 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2662fbc0f..318acbeff 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -150,7 +150,8 @@ var Metrics = tcpip.Stats{ TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), + CurrentConnected: mustCreateMetric("/netstack/tcp/current_open", "Number of connections that are in connected state."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3fc823a36..59c9b3fb0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -938,9 +938,13 @@ type TCPStats struct { PassiveConnectionOpenings *StatCounter // CurrentEstablished is the number of TCP connections for which the - // current state is either ESTABLISHED or CLOSE-WAIT. + // current state is ESTABLISHED. CurrentEstablished *StatCounter + // CurrentConnected is the number of TCP connections that + // are in connected state. + CurrentConnected *StatCounter + // EstablishedResets is the number of times TCP connections have made // a direct transition to the CLOSED state from either the // ESTABLISHED state or the CLOSE-WAIT state. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1a2e3efa9..d469758eb 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -562,7 +562,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // Switch state to connected. // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. - n.stack.Stats().TCP.CurrentEstablished.Increment() n.isConnectNotified = true n.setEndpointState(StateEstablished) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a2f384384..4e3c5419c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -934,6 +934,7 @@ func (e *endpoint) transitionToStateCloseLocked() { // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() e.setEndpointState(StateClose) + e.stack.Stats().TCP.CurrentConnected.Decrement() e.stack.Stats().TCP.EstablishedClosed.Increment() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4797f11d1..13718ff55 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -594,6 +594,7 @@ func (e *endpoint) setEndpointState(state EndpointState) { switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() + e.stack.Stats().TCP.CurrentConnected.Increment() case StateError: fallthrough case StateClose: diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a9dfbe857..df2fb1071 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -470,6 +470,89 @@ func TestConnectResetAfterClose(t *testing.T) { } } +// TestCurrentConnectedIncrement tests increment of the current +// established and connected counters. +func TestCurrentConnectedIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got) + } + gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() + if gotConnected != 1 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected) + } + + ep.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected) + } + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait for the TIME-WAIT state to transition to CLOSED. + time.Sleep(1 * time.Second) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got) + } +} + // TestClosingWithEnqueuedSegments tests handling of still enqueued segments // when the endpoint transitions to StateClose. The in-flight segments would be // re-enqueued to a any listening endpoint. -- cgit v1.2.3 From d29e59af9fbd420e34378bcbf7ae543134070217 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 10:04:07 -0800 Subject: Standardize on tools directory. PiperOrigin-RevId: 291745021 --- .bazelrc | 8 +- BUILD | 49 ++++++- benchmarks/defs.bzl | 18 --- benchmarks/harness/BUILD | 74 +++++----- benchmarks/harness/machine_producers/BUILD | 4 +- benchmarks/runner/BUILD | 24 ++-- benchmarks/tcp/BUILD | 3 +- benchmarks/workloads/ab/BUILD | 19 ++- benchmarks/workloads/absl/BUILD | 19 ++- benchmarks/workloads/curl/BUILD | 2 +- benchmarks/workloads/ffmpeg/BUILD | 2 +- benchmarks/workloads/fio/BUILD | 19 ++- benchmarks/workloads/httpd/BUILD | 2 +- benchmarks/workloads/iperf/BUILD | 19 ++- benchmarks/workloads/netcat/BUILD | 2 +- benchmarks/workloads/nginx/BUILD | 2 +- benchmarks/workloads/node/BUILD | 2 +- benchmarks/workloads/node_template/BUILD | 2 +- benchmarks/workloads/redis/BUILD | 2 +- benchmarks/workloads/redisbenchmark/BUILD | 19 ++- benchmarks/workloads/ruby/BUILD | 2 +- benchmarks/workloads/ruby_template/BUILD | 2 +- benchmarks/workloads/sleep/BUILD | 2 +- benchmarks/workloads/sysbench/BUILD | 19 ++- benchmarks/workloads/syscall/BUILD | 19 ++- benchmarks/workloads/tensorflow/BUILD | 2 +- benchmarks/workloads/true/BUILD | 2 +- pkg/abi/BUILD | 3 +- pkg/abi/linux/BUILD | 6 +- pkg/amutex/BUILD | 6 +- pkg/atomicbitops/BUILD | 6 +- pkg/binary/BUILD | 6 +- pkg/bits/BUILD | 6 +- pkg/bpf/BUILD | 6 +- pkg/compressio/BUILD | 6 +- pkg/control/client/BUILD | 3 +- pkg/control/server/BUILD | 3 +- pkg/cpuid/BUILD | 8 +- pkg/eventchannel/BUILD | 16 +-- pkg/fd/BUILD | 6 +- pkg/fdchannel/BUILD | 8 +- pkg/fdnotifier/BUILD | 3 +- pkg/flipcall/BUILD | 8 +- pkg/fspath/BUILD | 13 +- pkg/gate/BUILD | 4 +- pkg/goid/BUILD | 6 +- pkg/ilist/BUILD | 6 +- pkg/linewriter/BUILD | 6 +- pkg/log/BUILD | 6 +- pkg/memutil/BUILD | 3 +- pkg/metric/BUILD | 23 +-- pkg/p9/BUILD | 6 +- pkg/p9/p9test/BUILD | 6 +- pkg/procid/BUILD | 8 +- pkg/rand/BUILD | 3 +- pkg/refs/BUILD | 6 +- pkg/seccomp/BUILD | 6 +- pkg/secio/BUILD | 6 +- pkg/segment/test/BUILD | 6 +- pkg/sentry/BUILD | 2 + pkg/sentry/arch/BUILD | 20 +-- pkg/sentry/context/BUILD | 3 +- pkg/sentry/context/contexttest/BUILD | 3 +- pkg/sentry/control/BUILD | 8 +- pkg/sentry/device/BUILD | 6 +- pkg/sentry/fs/BUILD | 6 +- pkg/sentry/fs/anon/BUILD | 3 +- pkg/sentry/fs/dev/BUILD | 3 +- pkg/sentry/fs/fdpipe/BUILD | 6 +- pkg/sentry/fs/filetest/BUILD | 3 +- pkg/sentry/fs/fsutil/BUILD | 6 +- pkg/sentry/fs/gofer/BUILD | 6 +- pkg/sentry/fs/host/BUILD | 6 +- pkg/sentry/fs/lock/BUILD | 6 +- pkg/sentry/fs/proc/BUILD | 6 +- pkg/sentry/fs/proc/device/BUILD | 3 +- pkg/sentry/fs/proc/seqfile/BUILD | 6 +- pkg/sentry/fs/ramfs/BUILD | 6 +- pkg/sentry/fs/sys/BUILD | 3 +- pkg/sentry/fs/timerfd/BUILD | 3 +- pkg/sentry/fs/tmpfs/BUILD | 6 +- pkg/sentry/fs/tty/BUILD | 6 +- pkg/sentry/fsimpl/ext/BUILD | 6 +- pkg/sentry/fsimpl/ext/benchmark/BUILD | 2 +- pkg/sentry/fsimpl/ext/disklayout/BUILD | 6 +- pkg/sentry/fsimpl/kernfs/BUILD | 6 +- pkg/sentry/fsimpl/proc/BUILD | 8 +- pkg/sentry/fsimpl/sys/BUILD | 6 +- pkg/sentry/fsimpl/testutil/BUILD | 5 +- pkg/sentry/fsimpl/tmpfs/BUILD | 8 +- pkg/sentry/hostcpu/BUILD | 6 +- pkg/sentry/hostmm/BUILD | 3 +- pkg/sentry/inet/BUILD | 3 +- pkg/sentry/kernel/BUILD | 24 +--- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/contexttest/BUILD | 3 +- pkg/sentry/kernel/epoll/BUILD | 6 +- pkg/sentry/kernel/eventfd/BUILD | 6 +- pkg/sentry/kernel/fasync/BUILD | 3 +- pkg/sentry/kernel/futex/BUILD | 6 +- pkg/sentry/kernel/memevent/BUILD | 20 +-- pkg/sentry/kernel/pipe/BUILD | 6 +- pkg/sentry/kernel/sched/BUILD | 6 +- pkg/sentry/kernel/semaphore/BUILD | 6 +- pkg/sentry/kernel/shm/BUILD | 3 +- pkg/sentry/kernel/signalfd/BUILD | 5 +- pkg/sentry/kernel/time/BUILD | 3 +- pkg/sentry/limits/BUILD | 6 +- pkg/sentry/loader/BUILD | 4 +- pkg/sentry/memmap/BUILD | 6 +- pkg/sentry/mm/BUILD | 6 +- pkg/sentry/pgalloc/BUILD | 6 +- pkg/sentry/platform/BUILD | 3 +- pkg/sentry/platform/interrupt/BUILD | 6 +- pkg/sentry/platform/kvm/BUILD | 6 +- pkg/sentry/platform/kvm/testutil/BUILD | 3 +- pkg/sentry/platform/ptrace/BUILD | 3 +- pkg/sentry/platform/ring0/BUILD | 3 +- pkg/sentry/platform/ring0/gen_offsets/BUILD | 2 +- pkg/sentry/platform/ring0/pagetables/BUILD | 16 +-- pkg/sentry/platform/safecopy/BUILD | 6 +- pkg/sentry/safemem/BUILD | 6 +- pkg/sentry/sighandling/BUILD | 3 +- pkg/sentry/socket/BUILD | 3 +- pkg/sentry/socket/control/BUILD | 3 +- pkg/sentry/socket/hostinet/BUILD | 3 +- pkg/sentry/socket/netfilter/BUILD | 3 +- pkg/sentry/socket/netlink/BUILD | 3 +- pkg/sentry/socket/netlink/port/BUILD | 6 +- pkg/sentry/socket/netlink/route/BUILD | 3 +- pkg/sentry/socket/netlink/uevent/BUILD | 3 +- pkg/sentry/socket/netstack/BUILD | 3 +- pkg/sentry/socket/unix/BUILD | 3 +- pkg/sentry/socket/unix/transport/BUILD | 3 +- pkg/sentry/state/BUILD | 3 +- pkg/sentry/strace/BUILD | 20 +-- pkg/sentry/syscalls/BUILD | 3 +- pkg/sentry/syscalls/linux/BUILD | 3 +- pkg/sentry/time/BUILD | 6 +- pkg/sentry/unimpl/BUILD | 21 +-- pkg/sentry/uniqueid/BUILD | 3 +- pkg/sentry/usage/BUILD | 5 +- pkg/sentry/usermem/BUILD | 7 +- pkg/sentry/vfs/BUILD | 8 +- pkg/sentry/watchdog/BUILD | 3 +- pkg/sleep/BUILD | 6 +- pkg/state/BUILD | 17 +-- pkg/state/statefile/BUILD | 6 +- pkg/sync/BUILD | 6 +- pkg/sync/atomicptrtest/BUILD | 6 +- pkg/sync/seqatomictest/BUILD | 6 +- pkg/syserr/BUILD | 3 +- pkg/syserror/BUILD | 4 +- pkg/tcpip/BUILD | 6 +- pkg/tcpip/adapters/gonet/BUILD | 6 +- pkg/tcpip/buffer/BUILD | 6 +- pkg/tcpip/checker/BUILD | 3 +- pkg/tcpip/hash/jenkins/BUILD | 6 +- pkg/tcpip/header/BUILD | 6 +- pkg/tcpip/iptables/BUILD | 3 +- pkg/tcpip/link/channel/BUILD | 3 +- pkg/tcpip/link/fdbased/BUILD | 6 +- pkg/tcpip/link/loopback/BUILD | 3 +- pkg/tcpip/link/muxed/BUILD | 6 +- pkg/tcpip/link/rawfile/BUILD | 3 +- pkg/tcpip/link/sharedmem/BUILD | 6 +- pkg/tcpip/link/sharedmem/pipe/BUILD | 6 +- pkg/tcpip/link/sharedmem/queue/BUILD | 6 +- pkg/tcpip/link/sniffer/BUILD | 3 +- pkg/tcpip/link/tun/BUILD | 3 +- pkg/tcpip/link/waitable/BUILD | 6 +- pkg/tcpip/network/BUILD | 2 +- pkg/tcpip/network/arp/BUILD | 4 +- pkg/tcpip/network/fragmentation/BUILD | 6 +- pkg/tcpip/network/hash/BUILD | 3 +- pkg/tcpip/network/ipv4/BUILD | 4 +- pkg/tcpip/network/ipv6/BUILD | 6 +- pkg/tcpip/ports/BUILD | 6 +- pkg/tcpip/sample/tun_tcp_connect/BUILD | 2 +- pkg/tcpip/sample/tun_tcp_echo/BUILD | 2 +- pkg/tcpip/seqnum/BUILD | 3 +- pkg/tcpip/stack/BUILD | 6 +- pkg/tcpip/transport/icmp/BUILD | 3 +- pkg/tcpip/transport/packet/BUILD | 3 +- pkg/tcpip/transport/raw/BUILD | 3 +- pkg/tcpip/transport/tcp/BUILD | 4 +- pkg/tcpip/transport/tcp/testing/context/BUILD | 3 +- pkg/tcpip/transport/tcpconntrack/BUILD | 4 +- pkg/tcpip/transport/udp/BUILD | 4 +- pkg/tmutex/BUILD | 6 +- pkg/unet/BUILD | 6 +- pkg/urpc/BUILD | 6 +- pkg/waiter/BUILD | 6 +- runsc/BUILD | 27 ++-- runsc/boot/BUILD | 5 +- runsc/boot/filter/BUILD | 3 +- runsc/boot/platforms/BUILD | 3 +- runsc/cgroup/BUILD | 5 +- runsc/cmd/BUILD | 5 +- runsc/console/BUILD | 3 +- runsc/container/BUILD | 5 +- runsc/container/test_app/BUILD | 4 +- runsc/criutil/BUILD | 3 +- runsc/dockerutil/BUILD | 3 +- runsc/fsgofer/BUILD | 9 +- runsc/fsgofer/filter/BUILD | 3 +- runsc/sandbox/BUILD | 3 +- runsc/specutils/BUILD | 5 +- runsc/testutil/BUILD | 3 +- runsc/version_test.sh | 2 +- scripts/common.sh | 6 +- scripts/common_bazel.sh | 99 ------------- scripts/common_build.sh | 99 +++++++++++++ test/BUILD | 45 +----- test/e2e/BUILD | 5 +- test/image/BUILD | 5 +- test/iptables/BUILD | 5 +- test/iptables/runner/BUILD | 12 +- test/root/BUILD | 5 +- test/root/testdata/BUILD | 3 +- test/runtimes/BUILD | 4 +- test/runtimes/build_defs.bzl | 5 +- test/runtimes/images/proctor/BUILD | 4 +- test/syscalls/BUILD | 2 +- test/syscalls/build_defs.bzl | 6 +- test/syscalls/gtest/BUILD | 7 +- test/syscalls/linux/BUILD | 23 ++- test/syscalls/linux/arch_prctl.cc | 2 + test/syscalls/linux/rseq/BUILD | 5 +- .../linux/udp_socket_errqueue_test_case.cc | 4 + test/uds/BUILD | 3 +- test/util/BUILD | 27 ++-- test/util/save_util_linux.cc | 4 + test/util/save_util_other.cc | 4 + test/util/test_util_runfiles.cc | 4 + tools/BUILD | 3 + tools/build/BUILD | 10 ++ tools/build/defs.bzl | 91 ++++++++++++ tools/checkunsafe/BUILD | 3 +- tools/defs.bzl | 154 +++++++++++++++++++++ tools/go_generics/BUILD | 2 +- tools/go_generics/globals/BUILD | 4 +- tools/go_generics/go_merge/BUILD | 2 +- tools/go_generics/rules_tests/BUILD | 2 +- tools/go_marshal/BUILD | 4 +- tools/go_marshal/README.md | 52 +------ tools/go_marshal/analysis/BUILD | 5 +- tools/go_marshal/defs.bzl | 112 ++------------- tools/go_marshal/gomarshal/BUILD | 6 +- tools/go_marshal/gomarshal/generator.go | 20 ++- tools/go_marshal/gomarshal/generator_tests.go | 6 +- tools/go_marshal/main.go | 11 +- tools/go_marshal/marshal/BUILD | 5 +- tools/go_marshal/test/BUILD | 7 +- tools/go_marshal/test/external/BUILD | 6 +- tools/go_stateify/BUILD | 2 +- tools/go_stateify/defs.bzl | 79 +---------- tools/images/BUILD | 2 +- tools/images/defs.bzl | 6 +- tools/issue_reviver/BUILD | 2 +- tools/issue_reviver/github/BUILD | 3 +- tools/issue_reviver/reviver/BUILD | 5 +- tools/workspace_status.sh | 2 +- vdso/BUILD | 33 ++--- 264 files changed, 1012 insertions(+), 1380 deletions(-) delete mode 100644 benchmarks/defs.bzl delete mode 100755 scripts/common_bazel.sh create mode 100755 scripts/common_build.sh create mode 100644 tools/BUILD create mode 100644 tools/build/BUILD create mode 100644 tools/build/defs.bzl create mode 100644 tools/defs.bzl (limited to 'pkg/sentry/socket/netstack') diff --git a/.bazelrc b/.bazelrc index 9c35c5e7b..ef214bcfa 100644 --- a/.bazelrc +++ b/.bazelrc @@ -30,10 +30,10 @@ build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. -build:remote --host_platform=//test:rbe_ubuntu1604 -build:remote --extra_toolchains=//test:cc-toolchain-clang-x86_64-default -build:remote --extra_execution_platforms=//test:rbe_ubuntu1604 -build:remote --platforms=//test:rbe_ubuntu1604 +build:remote --host_platform=//:rbe_ubuntu1604 +build:remote --extra_toolchains=//:cc-toolchain-clang-x86_64-default +build:remote --extra_execution_platforms=//:rbe_ubuntu1604 +build:remote --platforms=//:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain build:remote --jobs=50 build:remote --remote_timeout=3600 diff --git a/BUILD b/BUILD index 76286174f..5fd929378 100644 --- a/BUILD +++ b/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) # Apache 2.0 - load("@io_bazel_rules_go//go:def.bzl", "go_path", "nogo") load("@bazel_gazelle//:def.bzl", "gazelle") +package(licenses = ["notice"]) + # The sandbox filegroup is used for sandbox-internal dependencies. package_group( name = "sandbox", @@ -49,9 +49,52 @@ gazelle(name = "gazelle") # live in the tools subdirectory (unless they are standard). nogo( name = "nogo", - config = "tools/nogo.js", + config = "//tools:nogo.js", visibility = ["//visibility:public"], deps = [ "//tools/checkunsafe", ], ) + +# We need to define a bazel platform and toolchain to specify dockerPrivileged +# and dockerRunAsRoot options, they are required to run tests on the RBE +# cluster in Kokoro. +alias( + name = "rbe_ubuntu1604", + actual = ":rbe_ubuntu1604_r346485", +) + +platform( + name = "rbe_ubuntu1604_r346485", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + "@bazel_toolchains//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +toolchain( + name = "cc-toolchain-clang-x86_64-default", + exec_compatible_with = [ + ], + target_compatible_with = [ + ], + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl deleted file mode 100644 index 79e6cdbc8..000000000 --- a/benchmarks/defs.bzl +++ /dev/null @@ -1,18 +0,0 @@ -"""Provides python helper functions.""" - -load("@pydeps//:requirements.bzl", _requirement = "requirement") - -def filter_deps(deps = None): - if deps == None: - deps = [] - return [dep for dep in deps if dep] - -def py_library(deps = None, **kwargs): - return native.py_library(deps = filter_deps(deps), **kwargs) - -def py_test(deps = None, **kwargs): - return native.py_test(deps = filter_deps(deps), **kwargs) - -def requirement(name, direct = True): - """ requirement returns the required dependency. """ - return _requirement(name) diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 081a74243..52d4e42f8 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -25,16 +25,16 @@ py_library( srcs = ["container.py"], deps = [ "//benchmarks/workloads", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -47,17 +47,17 @@ py_library( "//benchmarks/harness:ssh_connection", "//benchmarks/harness:tunnel_dispatcher", "//benchmarks/harness/machine_mocks", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("six", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("six", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -66,10 +66,10 @@ py_library( srcs = ["ssh_connection.py"], deps = [ "//benchmarks/harness", - requirement("bcrypt", False), - requirement("cffi", True), - requirement("paramiko", True), - requirement("cryptography", False), + py_requirement("bcrypt", False), + py_requirement("cffi", True), + py_requirement("paramiko", True), + py_requirement("cryptography", False), ], ) @@ -77,16 +77,16 @@ py_library( name = "tunnel_dispatcher", srcs = ["tunnel_dispatcher.py"], deps = [ - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("pexpect", True), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("pexpect", True), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD index c4e943882..48ea0ef39 100644 --- a/benchmarks/harness/machine_producers/BUILD +++ b/benchmarks/harness/machine_producers/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -31,7 +31,7 @@ py_library( deps = [ "//benchmarks/harness:machine", "//benchmarks/harness/machine_producers:machine_producer", - requirement("PyYAML", False), + py_requirement("PyYAML", False), ], ) diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD index e1b2ea550..fae0ca800 100644 --- a/benchmarks/runner/BUILD +++ b/benchmarks/runner/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") package(licenses = ["notice"]) @@ -28,7 +28,7 @@ py_library( "//benchmarks/suites:startup", "//benchmarks/suites:sysbench", "//benchmarks/suites:syscall", - requirement("click", True), + py_requirement("click", True), ], ) @@ -36,7 +36,7 @@ py_library( name = "commands", srcs = ["commands.py"], deps = [ - requirement("click", True), + py_requirement("click", True), ], ) @@ -50,14 +50,14 @@ py_test( ], deps = [ ":runner", - requirement("click", True), - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("click", True), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD index 735d7127f..d5e401acc 100644 --- a/benchmarks/tcp/BUILD +++ b/benchmarks/tcp/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary", "go_binary") package(licenses = ["notice"]) diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD index 4fc0ab735..4dd91ceb3 100644 --- a/benchmarks/workloads/ab/BUILD +++ b/benchmarks/workloads/ab/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":ab", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD index 61e010096..55dae3baa 100644 --- a/benchmarks/workloads/absl/BUILD +++ b/benchmarks/workloads/absl/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":absl", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/curl/BUILD +++ b/benchmarks/workloads/curl/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD index be472dfb2..7c41ba631 100644 --- a/benchmarks/workloads/ffmpeg/BUILD +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD index de257adad..7b78e8e75 100644 --- a/benchmarks/workloads/fio/BUILD +++ b/benchmarks/workloads/fio/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":fio", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/httpd/BUILD +++ b/benchmarks/workloads/httpd/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD index 8832a996c..570f40148 100644 --- a/benchmarks/workloads/iperf/BUILD +++ b/benchmarks/workloads/iperf/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":iperf", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/netcat/BUILD +++ b/benchmarks/workloads/netcat/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/nginx/BUILD +++ b/benchmarks/workloads/nginx/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD index 71cd9f519..bfcf78cf9 100644 --- a/benchmarks/workloads/node/BUILD +++ b/benchmarks/workloads/node/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD index ca996f068..e142f082a 100644 --- a/benchmarks/workloads/node_template/BUILD +++ b/benchmarks/workloads/node_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/redis/BUILD +++ b/benchmarks/workloads/redis/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD index f5994a815..f472a4443 100644 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":redisbenchmark", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD index e37d77804..a3be4fe92 100644 --- a/benchmarks/workloads/ruby/BUILD +++ b/benchmarks/workloads/ruby/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD index 27f7c0c46..59443b14a 100644 --- a/benchmarks/workloads/ruby_template/BUILD +++ b/benchmarks/workloads/ruby_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/sleep/BUILD +++ b/benchmarks/workloads/sleep/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD index fd2f8f03d..3834af7ed 100644 --- a/benchmarks/workloads/sysbench/BUILD +++ b/benchmarks/workloads/sysbench/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":sysbench", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD index 5100cbb21..dba4bb1e7 100644 --- a/benchmarks/workloads/syscall/BUILD +++ b/benchmarks/workloads/syscall/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":syscall", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD index 026c3b316..a7b7742f4 100644 --- a/benchmarks/workloads/tensorflow/BUILD +++ b/benchmarks/workloads/tensorflow/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD index 221c4b9a7..eba23d325 100644 --- a/benchmarks/workloads/true/BUILD +++ b/benchmarks/workloads/true/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD index f5c08ea06..839f822eb 100644 --- a/pkg/abi/BUILD +++ b/pkg/abi/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,6 +9,5 @@ go_library( "abi_linux.go", "flag.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi", visibility = ["//:sandbox"], ) diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 716ff22d2..1f3c0c687 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") # Package linux contains the constants and types needed to interface with a # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix @@ -60,7 +59,6 @@ go_library( "wait.go", "xattr.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], deps = [ "//pkg/abi", @@ -73,7 +71,7 @@ go_test( name = "linux_test", size = "small", srcs = ["netfilter_test.go"], - embed = [":linux"], + library = ":linux", deps = [ "//pkg/binary", ], diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index d99e37b40..9612f072e 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "amutex", srcs = ["amutex.go"], - importpath = "gvisor.dev/gvisor/pkg/amutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "amutex_test", size = "small", srcs = ["amutex_test.go"], - embed = [":amutex"], + library = ":amutex", deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 6403c60c2..3948074ba 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], - importpath = "gvisor.dev/gvisor/pkg/atomicbitops", visibility = ["//:sandbox"], ) @@ -19,6 +17,6 @@ go_test( name = "atomicbitops_test", size = "small", srcs = ["atomic_bitops_test.go"], - embed = [":atomicbitops"], + library = ":atomicbitops", deps = ["//pkg/sync"], ) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 543fb54bf..7ca2fda90 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "binary", srcs = ["binary.go"], - importpath = "gvisor.dev/gvisor/pkg/binary", visibility = ["//:sandbox"], ) @@ -14,5 +12,5 @@ go_test( name = "binary_test", size = "small", srcs = ["binary_test.go"], - embed = [":binary"], + library = ":binary", ) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 93b88a29a..63f4670d7 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -15,7 +14,6 @@ go_library( "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], - importpath = "gvisor.dev/gvisor/pkg/bits", visibility = ["//:sandbox"], ) @@ -53,5 +51,5 @@ go_test( name = "bits_test", size = "small", srcs = ["uint64_test.go"], - embed = [":bits"], + library = ":bits", ) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index fba5643e8..2a6977f85 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "interpreter.go", "program_builder.go", ], - importpath = "gvisor.dev/gvisor/pkg/bpf", visibility = ["//visibility:public"], deps = ["//pkg/abi/linux"], ) @@ -25,7 +23,7 @@ go_test( "interpreter_test.go", "program_builder_test.go", ], - embed = [":bpf"], + library = ":bpf", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 2bb581b18..1f75319a7 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "compressio", srcs = ["compressio.go"], - importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,5 +16,5 @@ go_test( name = "compressio_test", size = "medium", srcs = ["compressio_test.go"], - embed = [":compressio"], + library = ":compressio", ) diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD index 066d7b1a1..1b9e10ee7 100644 --- a/pkg/control/client/BUILD +++ b/pkg/control/client/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "client.go", ], - importpath = "gvisor.dev/gvisor/pkg/control/client", visibility = ["//:sandbox"], deps = [ "//pkg/unet", diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index adbd1e3f8..002d2ef44 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "server", srcs = ["server.go"], - importpath = "gvisor.dev/gvisor/pkg/control/server", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index ed111fd2a..43a432190 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpu_amd64.s", "cpuid.go", ], - importpath = "gvisor.dev/gvisor/pkg/cpuid", visibility = ["//:sandbox"], deps = ["//pkg/log"], ) @@ -18,7 +16,7 @@ go_test( name = "cpuid_test", size = "small", srcs = ["cpuid_test.go"], - embed = [":cpuid"], + library = ":cpuid", ) go_test( @@ -27,6 +25,6 @@ go_test( srcs = [ "cpuid_parse_test.go", ], - embed = [":cpuid"], + library = ":cpuid", tags = ["manual"], ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9d68682c7..bee28b68d 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) @@ -10,7 +8,6 @@ go_library( "event.go", "rate.go", ], - importpath = "gvisor.dev/gvisor/pkg/eventchannel", visibility = ["//:sandbox"], deps = [ ":eventchannel_go_proto", @@ -24,22 +21,15 @@ go_library( ) proto_library( - name = "eventchannel_proto", + name = "eventchannel", srcs = ["event.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "eventchannel_go_proto", - importpath = "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto", - proto = ":eventchannel_proto", - visibility = ["//:sandbox"], -) - go_test( name = "eventchannel_test", srcs = ["event_test.go"], - embed = [":eventchannel"], + library = ":eventchannel", deps = [ "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index afa8f7659..872361546 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "fd", srcs = ["fd.go"], - importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], ) @@ -14,5 +12,5 @@ go_test( name = "fd_test", size = "small", srcs = ["fd_test.go"], - embed = [":fd"], + library = ":fd", ) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index b0478c672..d9104ef02 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "fdchannel", srcs = ["fdchannel_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/fdchannel", visibility = ["//visibility:public"], ) @@ -14,6 +12,6 @@ go_test( name = "fdchannel_test", size = "small", srcs = ["fdchannel_test.go"], - embed = [":fdchannel"], + library = ":fdchannel", deps = ["//pkg/sync"], ) diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index 91a202a30..235dcc490 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "fdnotifier.go", "poll_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 85bd83af1..9c5ad500b 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "flipcall", @@ -13,7 +12,6 @@ go_library( "io.go", "packet_window_allocator.go", ], - importpath = "gvisor.dev/gvisor/pkg/flipcall", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -30,6 +28,6 @@ go_test( "flipcall_example_test.go", "flipcall_test.go", ], - embed = [":flipcall"], + library = ":flipcall", deps = ["//pkg/sync"], ) diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index ca540363c..ee84471b2 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,10 +1,8 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) go_library( name = "fspath", @@ -13,7 +11,6 @@ go_library( "builder_unsafe.go", "fspath.go", ], - importpath = "gvisor.dev/gvisor/pkg/fspath", ) go_test( @@ -23,5 +20,5 @@ go_test( "builder_test.go", "fspath_test.go", ], - embed = [":fspath"], + library = ":fspath", ) diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index f22bd070d..dd3141143 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "gate.go", ], - importpath = "gvisor.dev/gvisor/pkg/gate", visibility = ["//visibility:public"], ) diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index 5d31e5366..ea8d2422c 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "goid_race.go", "goid_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/goid", visibility = ["//visibility:public"], ) @@ -22,5 +20,5 @@ go_test( "empty_test.go", "goid_test.go", ], - embed = [":goid"], + library = ":goid", ) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 34d2673ef..3f6eb07df 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( srcs = [ "interface_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/ilist", visibility = ["//visibility:public"], ) @@ -41,7 +39,7 @@ go_test( "list_test.go", "test_list.go", ], - embed = [":ilist"], + library = ":ilist", ) go_template( diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index bcde6d308..41bf104d0 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "linewriter", srcs = ["linewriter.go"], - importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "linewriter_test", srcs = ["linewriter_test.go"], - embed = [":linewriter"], + library = ":linewriter", ) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 0df0f2849..935d06963 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "json_k8s.go", "log.go", ], - importpath = "gvisor.dev/gvisor/pkg/log", visibility = [ "//visibility:public", ], @@ -29,5 +27,5 @@ go_test( "json_test.go", "log_test.go", ], - embed = [":log"], + library = ":log", ) diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD index 7b50e2b28..9d07d98b4 100644 --- a/pkg/memutil/BUILD +++ b/pkg/memutil/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "memutil", srcs = ["memutil_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/memutil", visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 9145f3233..58305009d 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,14 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) go_library( name = "metric", srcs = ["metric.go"], - importpath = "gvisor.dev/gvisor/pkg/metric", visibility = ["//:sandbox"], deps = [ ":metric_go_proto", @@ -19,28 +15,15 @@ go_library( ) proto_library( - name = "metric_proto", + name = "metric", srcs = ["metric.proto"], visibility = ["//:sandbox"], ) -cc_proto_library( - name = "metric_cc_proto", - visibility = ["//:sandbox"], - deps = [":metric_proto"], -) - -go_proto_library( - name = "metric_go_proto", - importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", - proto = ":metric_proto", - visibility = ["//:sandbox"], -) - go_test( name = "metric_test", srcs = ["metric_test.go"], - embed = [":metric"], + library = ":metric", deps = [ ":metric_go_proto", "//pkg/eventchannel", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index a3e05c96d..4ccc1de86 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package( default_visibility = ["//visibility:public"], @@ -23,7 +22,6 @@ go_library( "transport_flipcall.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", "//pkg/fdchannel", @@ -47,7 +45,7 @@ go_test( "transport_test.go", "version_test.go", ], - embed = [":p9"], + library = ":p9", deps = [ "//pkg/fd", "//pkg/unet", diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index f4edd68b2..7ca67cb19 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") package(licenses = ["notice"]) @@ -64,7 +63,6 @@ go_library( "mocks.go", "p9test.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9/p9test", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -80,7 +78,7 @@ go_test( name = "client_test", size = "medium", srcs = ["client_test.go"], - embed = [":p9test"], + library = ":p9test", deps = [ "//pkg/fd", "//pkg/p9", diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index b506813f0..aa3e3ac0b 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "procid_amd64.s", "procid_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/procid", visibility = ["//visibility:public"], ) @@ -20,7 +18,7 @@ go_test( srcs = [ "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) @@ -31,6 +29,6 @@ go_test( "procid_net_test.go", "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index 9d5b4859b..80b8ceb02 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "rand.go", "rand_linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 974d9af9b..74affc887 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "refcounter_state.go", "weak_ref_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/refs", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -35,6 +33,6 @@ go_test( name = "refs_test", size = "small", srcs = ["refcounter_test.go"], - embed = [":refs"], + library = ":refs", deps = ["//pkg/sync"], ) diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index af94e944d..742c8b79b 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") +load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "seccomp_rules.go", "seccomp_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/seccomp", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -43,7 +41,7 @@ go_test( "seccomp_test.go", ":victim_data", ], - embed = [":seccomp"], + library = ":seccomp", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index 22abdc69f..60f63c7a6 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "full_reader.go", "secio.go", ], - importpath = "gvisor.dev/gvisor/pkg/secio", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "secio_test", size = "small", srcs = ["secio_test.go"], - embed = [":secio"], + library = ":secio", ) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index a27c35e21..f2d8462d8 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package( @@ -38,7 +37,6 @@ go_library( "int_set.go", "set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/segment/segment", deps = [ "//pkg/state", ], @@ -48,5 +46,5 @@ go_test( name = "segment_test", size = "small", srcs = ["segment_test.go"], - embed = [":segment"], + library = ":segment", ) diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 2d6379c86..e8b794179 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) package_group( name = "internal", packages = [ + "//cloud/gvisor/gopkg/sentry/...", + "//cloud/gvisor/sentry/...", "//pkg/sentry/...", "//runsc/...", # Code generated by go_marshal relies on go_marshal libraries. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 65f22af2b..51ca09b24 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,6 +1,4 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -27,7 +25,6 @@ go_library( "syscalls_amd64.go", "syscalls_arm64.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/arch", visibility = ["//:sandbox"], deps = [ ":registers_go_proto", @@ -44,20 +41,7 @@ go_library( ) proto_library( - name = "registers_proto", + name = "registers", srcs = ["registers.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "registers_cc_proto", - visibility = ["//visibility:public"], - deps = [":registers_proto"], -) - -go_proto_library( - name = "registers_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", - proto = ":registers_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/context/BUILD b/pkg/sentry/context/BUILD index 8dc1a77b1..e13a9ce20 100644 --- a/pkg/sentry/context/BUILD +++ b/pkg/sentry/context/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "context", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/amutex", diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index 581e7aa96..f91a6d4ed 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/memutil", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2561a6109..e69496477 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,9 +11,8 @@ go_library( "proc.go", "state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/control", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "control_test", size = "small", srcs = ["proc_test.go"], - embed = [":control"], + library = ":control", deps = [ "//pkg/log", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 97fa1512c..e403cbd8b 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -18,5 +16,5 @@ go_test( name = "device_test", size = "small", srcs = ["device_test.go"], - embed = [":device"], + library = ":device", ) diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 7d5d72d5a..605d61dbe 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -44,7 +43,6 @@ go_library( "splice.go", "sync.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -129,7 +127,7 @@ go_test( "mount_test.go", "path_test.go", ], - embed = [":fs"], + library = ":fs", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD index ae1c9cf76..c14e5405e 100644 --- a/pkg/sentry/fs/anon/BUILD +++ b/pkg/sentry/fs/anon/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "anon.go", "device.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/anon", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index a0d9e8496..0c7247bd7 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "random.go", "tty.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/dev", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index cc43de69d..25ef96299 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "pipe_opener.go", "pipe_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe", imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], visibility = ["//pkg/sentry:internal"], deps = [ @@ -36,7 +34,7 @@ go_test( "pipe_opener_test.go", "pipe_test.go", ], - embed = [":fdpipe"], + library = ":fdpipe", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index 358dc2be3..9a7608cae 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "filetest", testonly = 1, srcs = ["filetest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/filetest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 945b6270d..9142f5bdf 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -75,7 +74,6 @@ go_library( "inode.go", "inode_cached.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fsutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -106,7 +104,7 @@ go_test( "dirty_set_test.go", "inode_cached_test.go", ], - embed = [":fsutil"], + library = ":fsutil", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index fd870e8e1..cf48e7c03 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "socket.go", "util.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/gofer", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -56,7 +54,7 @@ go_test( name = "gofer_test", size = "small", srcs = ["gofer_test.go"], - embed = [":gofer"], + library = ":gofer", deps = [ "//pkg/p9", "//pkg/p9/p9test", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 2b581aa69..f586f47c1 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "util_arm64_unsafe.go", "util_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -69,7 +67,7 @@ go_test( "socket_test.go", "wait_test.go", ], - embed = [":host"], + library = ":host", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 2c332a82a..ae3331737 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -40,7 +39,6 @@ go_library( "lock_set.go", "lock_set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/lock", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -56,5 +54,5 @@ go_test( "lock_range_test.go", "lock_test.go", ], - embed = [":lock"], + library = ":lock", ) diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index cb37c6c6b..b06bead41 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "uptime.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -63,7 +61,7 @@ go_test( "net_test.go", "sys_net_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD index 0394451d4..52c9aa93d 100644 --- a/pkg/sentry/fs/proc/device/BUILD +++ b/pkg/sentry/fs/proc/device/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/device", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sentry/device"], ) diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 38b246dff..310d8dd52 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "seqfile", srcs = ["seqfile.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "seqfile_test", size = "small", srcs = ["seqfile_test.go"], - embed = [":seqfile"], + library = ":seqfile", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 3fb7b0633..39c4b84f8 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "symlink.go", "tree.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ramfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -31,7 +29,7 @@ go_test( name = "ramfs_test", size = "small", srcs = ["tree_test.go"], - embed = [":ramfs"], + library = ":ramfs", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index 25f0f124e..cc6b3bfbf 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "fs.go", "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/sys", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index a215c1b95..092668e8d 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "timerfd", srcs = ["timerfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/timerfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 3400b940c..04776555f 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "inode_file.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -41,7 +39,7 @@ go_test( name = "tmpfs_test", size = "small", srcs = ["file_test.go"], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/sentry/context", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index f6f60d0cf..29f804c6c 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +13,6 @@ go_library( "slave.go", "terminal.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tty", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "tty_test", size = "small", srcs = ["tty_test.go"], - embed = [":tty"], + library = ":tty", deps = [ "//pkg/abi/linux", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 903874141..a718920d5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -32,7 +31,6 @@ go_library( "symlink.go", "utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -71,7 +69,7 @@ go_test( "//pkg/sentry/fsimpl/ext:assets/tiny.ext3", "//pkg/sentry/fsimpl/ext:assets/tiny.ext4", ], - embed = [":ext"], + library = ":ext", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 4fc8296ef..12f3990c1 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index fcfaf5c3e..9bd9c76c0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "superblock_old.go", "test_utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -44,6 +42,6 @@ go_test( "inode_test.go", "superblock_test.go", ], - embed = [":disklayout"], + library = ":disklayout", deps = ["//pkg/sentry/kernel/time"], ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 66d409785..7bf83ccba 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "slot_list", @@ -27,7 +26,6 @@ go_library( "slot_list.go", "symlink.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index c5b79fb38..3768f55b2 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "proc", @@ -15,7 +14,6 @@ go_library( "tasks_net.go", "tasks_sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc", deps = [ "//pkg/abi/linux", "//pkg/log", @@ -47,7 +45,7 @@ go_test( "tasks_sys_test.go", "tasks_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index ee3c842bd..beda141f1 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -1,14 +1,12 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "sys", srcs = [ "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 4e70d84a7..12053a5b6 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -1,6 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "testutil", @@ -9,7 +9,6 @@ go_library( "kernel.go", "testutil.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 691476b4f..857e98bc5 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "dentry_list", @@ -28,7 +27,6 @@ go_library( "symlink.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs", deps = [ "//pkg/abi/linux", "//pkg/amutex", @@ -81,7 +79,7 @@ go_test( "regular_file_test.go", "stat_test.go", ], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index 359468ccc..e6933aa70 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "getcpu_arm64.s", "hostcpu.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", visibility = ["//:sandbox"], ) @@ -18,5 +16,5 @@ go_test( name = "hostcpu_test", size = "small", srcs = ["hostcpu_test.go"], - embed = [":hostcpu"], + library = ":hostcpu", ) diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 67831d5a1..a145a5ca3 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "cgroup.go", "hostmm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostmm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/fd", diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 8d60ad4ad..aa621b724 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -12,7 +12,6 @@ go_library( "inet.go", "test_stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/inet", deps = [ "//pkg/sentry/context", "//pkg/tcpip/stack", diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index ac85ba0c8..cebaccd92 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,8 +1,5 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -78,26 +75,12 @@ go_template_instance( ) proto_library( - name = "uncaught_signal_proto", + name = "uncaught_signal", srcs = ["uncaught_signal.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "uncaught_signal_cc_proto", - visibility = ["//visibility:public"], - deps = [":uncaught_signal_proto"], -) - -go_proto_library( - name = "uncaught_signal_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", - proto = ":uncaught_signal_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "kernel", srcs = [ @@ -156,7 +139,6 @@ go_library( "vdso.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel", imports = [ "gvisor.dev/gvisor/pkg/bpf", "gvisor.dev/gvisor/pkg/sentry/device", @@ -227,7 +209,7 @@ go_test( "task_test.go", "timekeeper_test.go", ], - embed = [":kernel"], + library = ":kernel", deps = [ "//pkg/abi", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 1aa72fa47..64537c9be 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "id_map_set.go", "user_namespace.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/auth", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index 3a88a585c..daff608d7 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index c47f6b6fc..19e16ab3a 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "epoll_list.go", "epoll_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/epoll", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/refs", @@ -43,7 +41,7 @@ go_test( srcs = [ "epoll_test.go", ], - embed = [":epoll"], + library = ":epoll", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs/filetest", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index c831fbab2..ee2d74864 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "eventfd", srcs = ["eventfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "eventfd_test", size = "small", srcs = ["eventfd_test.go"], - embed = [":eventfd"], + library = ":eventfd", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/usermem", diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 6b36bc63e..b9126e946 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "fasync", srcs = ["fasync.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/fasync", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 50db443ce..f413d8ae2 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -34,7 +33,6 @@ go_library( "futex.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/futex", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -51,7 +49,7 @@ go_test( name = "futex_test", size = "small", srcs = ["futex_test.go"], - embed = [":futex"], + library = ":futex", deps = [ "//pkg/sentry/usermem", "//pkg/sync", diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index 7f36252a9..4486848d2 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,13 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) go_library( name = "memevent", srcs = ["memory_events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent", visibility = ["//:sandbox"], deps = [ ":memory_events_go_proto", @@ -21,20 +18,7 @@ go_library( ) proto_library( - name = "memory_events_proto", + name = "memory_events", srcs = ["memory_events.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "memory_events_cc_proto", - visibility = ["//visibility:public"], - deps = [":memory_events_proto"], -) - -go_proto_library( - name = "memory_events_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", - proto = ":memory_events_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 5eeaeff66..2c7b6206f 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "vfs.go", "writer.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -57,7 +55,7 @@ go_test( "node_test.go", "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 98ea7a0d8..1b82e087b 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpuset.go", "sched.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/sched", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "sched_test", size = "small", srcs = ["cpuset_test.go"], - embed = [":sched"], + library = ":sched", ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 13a961594..76e19b551 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "semaphore.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "semaphore_test", size = "small", srcs = ["semaphore_test.go"], - embed = [":semaphore"], + library = ":semaphore", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 7321b22ed..5547c5abf 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "device.go", "shm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/shm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 89e4d84b1..5d44773d4 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "signalfd", srcs = ["signalfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 4e4de0512..d49594d9f 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "context.go", "time.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/time", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 9fa841e8b..67869757f 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "limits.go", "linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/limits", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -25,5 +23,5 @@ go_test( srcs = [ "limits_test.go", ], - embed = [":limits"], + library = ":limits", ) diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 2890393bd..d4ad2bd6c 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_embed_data", "go_library") package(licenses = ["notice"]) @@ -20,7 +19,6 @@ go_library( "vdso_state.go", ":vdso_bin", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/loader", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 112794e9c..f9a65f086 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -37,7 +36,6 @@ go_library( "mapping_set_impl.go", "memmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/memmap", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -52,6 +50,6 @@ go_test( name = "memmap_test", size = "small", srcs = ["mapping_set_test.go"], - embed = [":memmap"], + library = ":memmap", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 83e248431..bd6399fa2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -96,7 +95,6 @@ go_library( "vma.go", "vma_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/mm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -128,7 +126,7 @@ go_test( name = "mm_test", size = "small", srcs = ["mm_test.go"], - embed = [":mm"], + library = ":mm", deps = [ "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9a2642c5..02385a3ce 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -60,7 +59,6 @@ go_library( "save_restore.go", "usage_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/pgalloc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -82,6 +80,6 @@ go_test( name = "pgalloc_test", size = "small", srcs = ["pgalloc_test.go"], - embed = [":pgalloc"], + library = ":pgalloc", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 157bffa81..006450b2d 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "mmap_min_addr.go", "platform.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index 85e882df9..83b385f14 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "interrupt.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -17,5 +15,5 @@ go_test( name = "interrupt_test", size = "small", srcs = ["interrupt_test.go"], - embed = [":interrupt"], + library = ":interrupt", ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 6a358d1d4..a4532a766 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -38,7 +37,6 @@ go_library( "physical_map_arm64.go", "virtual_map.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -65,7 +63,7 @@ go_test( "kvm_test.go", "virtual_map_test.go", ], - embed = [":kvm"], + library = ":kvm", tags = [ "manual", "nogotsan", diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index b0e45f159..f7605df8a 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,6 +12,5 @@ go_library( "testutil_arm64.go", "testutil_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], ) diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index cd13390c3..3bcc5e040 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -20,7 +20,6 @@ go_library( "subprocess_linux_unsafe.go", "subprocess_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 87f4552b5..6dee8fcc5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -74,7 +74,6 @@ go_library( "lib_arm64.s", "ring0.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 42076fb04..147311ed3 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 387a7f6c3..8b5cdd6c1 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,17 +1,14 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test", "select_arch") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - go_template( name = "generic_walker", - srcs = ["walker_amd64.go"], + srcs = select_arch( + amd64 = ["walker_amd64.go"], + arm64 = ["walker_amd64.go"], + ), opt_types = [ "Visitor", ], @@ -91,7 +88,6 @@ go_library( "walker_map.go", "walker_unmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables", visibility = [ "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", @@ -111,6 +107,6 @@ go_test( "pagetables_test.go", "walker_check.go", ], - embed = [":pagetables"], + library = ":pagetables", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 6769cd0a5..b8747585b 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -17,7 +16,6 @@ go_library( "sighandler_amd64.s", "sighandler_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/syserror"], ) @@ -27,5 +25,5 @@ go_test( srcs = [ "safecopy_test.go", ], - embed = [":safecopy"], + library = ":safecopy", ) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index 884020f7b..3ab76da97 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "safemem.go", "seq_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/safemem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/platform/safecopy", @@ -25,5 +23,5 @@ go_test( "io_test.go", "seq_test.go", ], - embed = [":safemem"], + library = ":safemem", ) diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD index f561670c7..6c38a3f44 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sentry/sighandling/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/sighandling", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/abi/linux"], ) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 26176b10d..8e2b97afb 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "socket", srcs = ["socket.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 357517ed4..3850f6345 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "control", srcs = ["control.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control", imports = [ "gvisor.dev/gvisor/pkg/sentry/fs", ], diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4c44c7c0f..42bf7be6a 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "socket_unsafe.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index b70047d81..ed34a8308 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "netfilter.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter", # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 103933144..baaac13c6 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "provider.go", "socket.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 2d9f4ba9b..3a22923d8 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "port", srcs = ["port.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "port_test", srcs = ["port_test.go"], - embed = [":port"], + library = ":port", ) diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 1d4912753..2137c7aeb 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "route", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD index 0777f3baf..73fbdf1eb 100644 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uevent", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index f78784569..e3d1f90cb 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,7 +11,6 @@ go_library( "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5b6a154f6..bade18686 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "io.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index d7ba95dff..4bdfc9208 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -25,7 +25,6 @@ go_library( "transport_message_list.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 88765f4d6..0ea4aab8b 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "state_metadata.go", "state_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/state", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index aa1ac720c..ff6fafa63 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -21,7 +19,6 @@ go_library( "strace.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/strace", visibility = ["//:sandbox"], deps = [ ":strace_go_proto", @@ -42,20 +39,7 @@ go_library( ) proto_library( - name = "strace_proto", + name = "strace", srcs = ["strace.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "strace_cc_proto", - visibility = ["//visibility:public"], - deps = [":strace_proto"], -) - -go_proto_library( - name = "strace_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", - proto = ":strace_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index 79d972202..b8d1bd415 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "epoll.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 917f74e07..7d74e0f70 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "sys_xattr.go", "timespec.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", visibility = ["//:sandbox"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 3cde3a0be..04f81a35b 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -31,7 +30,6 @@ go_library( "tsc_amd64.s", "tsc_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/time", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -48,5 +46,5 @@ go_test( "parameters_test.go", "sampler_test.go", ], - embed = [":time"], + library = ":time", ) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index fc7614fff..370fa6ec5 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,34 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) proto_library( - name = "unimplemented_syscall_proto", + name = "unimplemented_syscall", srcs = ["unimplemented_syscall.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "unimplemented_syscall_cc_proto", - visibility = ["//visibility:public"], - deps = [":unimplemented_syscall_proto"], -) - -go_proto_library( - name = "unimplemented_syscall_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", - proto = ":unimplemented_syscall_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "unimpl", srcs = ["events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD index 86a87edd4..e9c18f170 100644 --- a/pkg/sentry/uniqueid/BUILD +++ b/pkg/sentry/uniqueid/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uniqueid", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/uniqueid", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index 5518ac3d0..099315613 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,9 +11,8 @@ go_library( "memory_unsafe.go", "usage.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usage", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/bits", diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index 684f59a6b..c8322e29e 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -29,7 +28,6 @@ go_library( "usermem_unsafe.go", "usermem_x86.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usermem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/atomicbitops", @@ -38,7 +36,6 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/safemem", "//pkg/syserror", - "//pkg/tcpip/buffer", ], ) @@ -49,7 +46,7 @@ go_test( "addr_range_seq_test.go", "usermem_test.go", ], - embed = [":usermem"], + library = ":usermem", deps = [ "//pkg/sentry/context", "//pkg/sentry/safemem", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 35c7be259..51acdc4e9 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "vfs", @@ -24,7 +23,6 @@ go_library( "testutil.go", "vfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/vfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -47,7 +45,7 @@ go_test( "file_description_impl_util_test.go", "mount_test.go", ], - embed = [":vfs"], + library = ":vfs", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 28f21f13d..1c5a1c9b6 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "watchdog", srcs = ["watchdog.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/watchdog", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index a23c86fb1..e131455f7 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "commit_noasm.go", "sleep_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sleep", visibility = ["//:sandbox"], ) @@ -22,5 +20,5 @@ go_test( srcs = [ "sleep_test.go", ], - embed = [":sleep"], + library = ":sleep", ) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index be93750bf..921af9d63 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -49,7 +47,7 @@ go_library( "state.go", "stats.go", ], - importpath = "gvisor.dev/gvisor/pkg/state", + stateify = False, visibility = ["//:sandbox"], deps = [ ":object_go_proto", @@ -58,21 +56,14 @@ go_library( ) proto_library( - name = "object_proto", + name = "object", srcs = ["object.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "object_go_proto", - importpath = "gvisor.dev/gvisor/pkg/state/object_go_proto", - proto = ":object_proto", - visibility = ["//:sandbox"], -) - go_test( name = "state_test", timeout = "long", srcs = ["state_test.go"], - embed = [":state"], + library = ":state", ) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index 8a865d229..e7581c09b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "statefile", srcs = ["statefile.go"], - importpath = "gvisor.dev/gvisor/pkg/state/statefile", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,6 +16,6 @@ go_test( name = "statefile_test", size = "small", srcs = ["statefile_test.go"], - embed = [":statefile"], + library = ":statefile", deps = ["//pkg/compressio"], ) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 97c4b3b1e..5340cf0d6 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template") package( @@ -40,7 +39,6 @@ go_library( "syncutil.go", "tmutex_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sync", ) go_test( @@ -51,5 +49,5 @@ go_test( "seqcount_test.go", "tmutex_test.go", ], - embed = [":sync"], + library = ":sync", ) diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD index 418eda29c..e97553254 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptrtest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,12 +17,11 @@ go_template_instance( go_library( name = "atomicptr", srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", ) go_test( name = "atomicptr_test", size = "small", srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], + library = ":atomicptr", ) diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD index eba21518d..5c38c783e 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomictest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,7 +17,6 @@ go_template_instance( go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", deps = [ "//pkg/sync", ], @@ -28,6 +26,6 @@ go_test( name = "seqatomic_test", size = "small", srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], + library = ":seqatomic", deps = ["//pkg/sync"], ) diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 5665ad4ee..7d760344a 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "netstack.go", "syserr.go", ], - importpath = "gvisor.dev/gvisor/pkg/syserr", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index bd3f9fd28..b13c15d9b 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "syserror", srcs = ["syserror.go"], - importpath = "gvisor.dev/gvisor/pkg/syserror", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 23e4b09e7..26f7ba86b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "time_unsafe.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -25,7 +23,7 @@ go_test( name = "tcpip_test", size = "small", srcs = ["tcpip_test.go"], - embed = [":tcpip"], + library = ":tcpip", ) go_test( diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 3df7d18d3..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -23,7 +21,7 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index d6c31bfa2..563bc78ea 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "prependable.go", "view.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", visibility = ["//visibility:public"], ) @@ -17,5 +15,5 @@ go_test( name = "buffer_test", size = "small", srcs = ["view_test.go"], - embed = [":buffer"], + library = ":buffer", ) diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index b6fa6fc37..ed434807f 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "checker", testonly = 1, srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index e648efa71..ff2719291 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "jenkins", srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", visibility = ["//visibility:public"], ) @@ -16,5 +14,5 @@ go_test( srcs = [ "jenkins_test.go", ], - embed = [":jenkins"], + library = ":jenkins", ) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index cd747d100..9da0d71f8 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -59,7 +57,7 @@ go_test( "eth_test.go", "ndp_test.go", ], - embed = [":header"], + library = ":header", deps = [ "//pkg/tcpip", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 297eaccaf..d1b73cfdf 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "targets.go", "types.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 7dbc05754..3974c464e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 66cc53ed4..abe725548 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,7 +12,6 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -30,7 +28,7 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index f35fcdff4..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1ac7948b6..82b441b79 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -19,7 +17,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index d8211e93d..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 09165dd4c..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -30,7 +28,7 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index a0d4ad0be..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", visibility = ["//visibility:public"], ) @@ -20,6 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 8c9234d54..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index d6ae0368a..230a8d53a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index a71a493fc..e5096ea38 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 134837943..0956d2c65 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,7 +21,7 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..6a4839fb8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index e7617229b..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index ed16076fd..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -42,6 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4e2aae9a3..0fef2b1f1 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index e4e273460..fb11874c6 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -27,7 +25,7 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index a6ef3bdcc..2bad05a2e 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "ports", srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -17,7 +15,7 @@ go_library( go_test( name = "ports_test", srcs = ["ports_test.go"], - embed = [":ports"], + library = ":ports", deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index d7496fde6..cf0a5fefe 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index 875561566..43264b76d 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index b31ddba2f..45f503845 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "seqnum", srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 783351a69..f5b750046 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "stack_global_state.go", "transport_demuxer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", visibility = ["//visibility:public"], deps = [ "//pkg/ilist", @@ -81,7 +79,7 @@ go_test( name = "stack_test", size = "small", srcs = ["linkaddrcache_test.go"], - embed = [":stack"], + library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 3aa23d529..ac18ec5b1 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "icmp_packet_list.go", "protocol.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 4858d150c..d22de6b26 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "endpoint_state.go", "packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2f2131ff7..c9baf4600 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "protocol.go", "raw_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0e3ab05ad..4acd9fb9a 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -55,7 +54,6 @@ go_library( "tcp_segment_list.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index b33ec2087..ce6a2c31d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "context", testonly = 1, srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ "//visibility:public", ], diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 43fcc27f0..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tcpconntrack", srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 57ff123e3..adc908e24 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "protocol.go", "udp_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 07778e4f7..2dcba84ae 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tmutex", srcs = ["tmutex.go"], - importpath = "gvisor.dev/gvisor/pkg/tmutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "tmutex_test", size = "medium", srcs = ["tmutex_test.go"], - embed = [":tmutex"], + library = ":tmutex", deps = ["//pkg/sync"], ) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index d1885ae66..a86501fa2 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "unet.go", "unet_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/unet", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,6 +21,6 @@ go_test( srcs = [ "unet_test.go", ], - embed = [":unet"], + library = ":unet", deps = ["//pkg/sync"], ) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b8fdc3125..850c34ed0 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "urpc", srcs = ["urpc.go"], - importpath = "gvisor.dev/gvisor/pkg/urpc", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -20,6 +18,6 @@ go_test( name = "urpc_test", size = "small", srcs = ["urpc_test.go"], - embed = [":urpc"], + library = ":urpc", deps = ["//pkg/unet"], ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1c6890e52..852480a09 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "waiter.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -33,5 +31,5 @@ go_test( srcs = [ "waiter_test.go", ], - embed = [":waiter"], + library = ":waiter", ) diff --git a/runsc/BUILD b/runsc/BUILD index e5587421d..b35b41d81 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,7 +1,6 @@ -package(licenses = ["notice"]) # Apache 2.0 +load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar") +package(licenses = ["notice"]) go_binary( name = "runsc", @@ -9,7 +8,7 @@ go_binary( "main.go", "version.go", ], - pure = "on", + pure = True, visibility = [ "//visibility:public", ], @@ -26,10 +25,12 @@ go_binary( ) # The runsc-race target is a race-compatible BUILD target. This must be built -# via "bazel build --features=race //runsc:runsc-race", since the race feature -# must apply to all dependencies due a bug in gazelle file selection. The pure -# attribute must be off because the race detector requires linking with non-Go -# components, although we still require a static binary. +# via: bazel build --features=race //runsc:runsc-race +# +# This is neccessary because the race feature must apply to all dependencies +# due a bug in gazelle file selection. The pure attribute must be off because +# the race detector requires linking with non-Go components, although we still +# require a static binary. # # Note that in the future this might be convertible to a compatible target by # using the pure and static attributes within a select function, but select is @@ -42,7 +43,7 @@ go_binary( "main.go", "version.go", ], - static = "on", + static = True, visibility = [ "//visibility:public", ], @@ -82,7 +83,12 @@ genrule( # because they are assumes to be hermetic). srcs = [":runsc"], outs = ["version.txt"], - cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@", + # Note that the little dance here is necessary because files in the $(SRCS) + # attribute are not executable by default, and we can't touch in place. + cmd = "cp $(location :runsc) $(@D)/runsc && \ + chmod a+x $(@D)/runsc && \ + $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ + rm -f $(@D)/runsc", stamp = 1, ) @@ -109,5 +115,6 @@ sh_test( name = "version_test", size = "small", srcs = ["version_test.sh"], + args = ["$(location :runsc)"], data = [":runsc"], ) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 3e20f8f2f..f3ebc0231 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "strace.go", "user.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -107,7 +106,7 @@ go_test( "loader_test.go", "user_test.go", ], - embed = [":boot"], + library = ":boot", deps = [ "//pkg/control/server", "//pkg/log", diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index 3a9dcfc04..ce30f6c53 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot/filter", visibility = [ "//runsc/boot:__subpackages__", ], diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD index 03391cdca..77774f43c 100644 --- a/runsc/boot/platforms/BUILD +++ b/runsc/boot/platforms/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "platforms", srcs = ["platforms.go"], - importpath = "gvisor.dev/gvisor/runsc/boot/platforms", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD index d6165f9e5..d4c7bdfbb 100644 --- a/runsc/cgroup/BUILD +++ b/runsc/cgroup/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "cgroup", srcs = ["cgroup.go"], - importpath = "gvisor.dev/gvisor/runsc/cgroup", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -19,6 +18,6 @@ go_test( name = "cgroup_test", size = "small", srcs = ["cgroup_test.go"], - embed = [":cgroup"], + library = ":cgroup", tags = ["local"], ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index b94bc4fa0..09aa46434 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -34,7 +34,6 @@ go_library( "syscalls.go", "wait.go", ], - importpath = "gvisor.dev/gvisor/runsc/cmd", visibility = [ "//runsc:__subpackages__", ], @@ -73,7 +72,7 @@ go_test( data = [ "//runsc", ], - embed = [":cmd"], + library = ":cmd", deps = [ "//pkg/abi/linux", "//pkg/log", diff --git a/runsc/console/BUILD b/runsc/console/BUILD index e623c1a0f..06924bccd 100644 --- a/runsc/console/BUILD +++ b/runsc/console/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "console.go", ], - importpath = "gvisor.dev/gvisor/runsc/console", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 6dea179e4..e21431e4c 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "state_file.go", "status.go", ], - importpath = "gvisor.dev/gvisor/runsc/container", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -42,7 +41,7 @@ go_test( "//runsc", "//runsc/container/test_app", ], - embed = [":container"], + library = ":container", shard_count = 5, tags = [ "requires-kvm", diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD index bfd338bb6..e200bafd9 100644 --- a/runsc/container/test_app/BUILD +++ b/runsc/container/test_app/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -9,7 +9,7 @@ go_binary( "fds.go", "test_app.go", ], - pure = "on", + pure = True, visibility = ["//runsc/container:__pkg__"], deps = [ "//pkg/unet", diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD index 558133a0e..8a571a000 100644 --- a/runsc/criutil/BUILD +++ b/runsc/criutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "criutil", testonly = 1, srcs = ["criutil.go"], - importpath = "gvisor.dev/gvisor/runsc/criutil", visibility = ["//:sandbox"], deps = ["//runsc/testutil"], ) diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD index 0e0423504..8621af901 100644 --- a/runsc/dockerutil/BUILD +++ b/runsc/dockerutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "dockerutil", testonly = 1, srcs = ["dockerutil.go"], - importpath = "gvisor.dev/gvisor/runsc/dockerutil", visibility = ["//:sandbox"], deps = [ "//runsc/testutil", diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index a9582d92b..64a406ae2 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,10 +10,7 @@ go_library( "fsgofer_arm64_unsafe.go", "fsgofer_unsafe.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer", - visibility = [ - "//runsc:__subpackages__", - ], + visibility = ["//runsc:__subpackages__"], deps = [ "//pkg/abi/linux", "//pkg/fd", @@ -30,7 +27,7 @@ go_test( name = "fsgofer_test", size = "small", srcs = ["fsgofer_test.go"], - embed = [":fsgofer"], + library = ":fsgofer", deps = [ "//pkg/log", "//pkg/p9", diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index bac73f89d..82b48ef32 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer/filter", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index ddbc37456..c95d50294 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "network_unsafe.go", "sandbox.go", ], - importpath = "gvisor.dev/gvisor/runsc/sandbox", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 205638803..4ccd77f63 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "namespace.go", "specutils.go", ], - importpath = "gvisor.dev/gvisor/runsc/specutils", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -28,6 +27,6 @@ go_test( name = "specutils_test", size = "small", srcs = ["specutils_test.go"], - embed = [":specutils"], + library = ":specutils", deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], ) diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index 3c3027cb5..f845120b0 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "testutil", testonly = 1, srcs = ["testutil.go"], - importpath = "gvisor.dev/gvisor/runsc/testutil", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/runsc/version_test.sh b/runsc/version_test.sh index cc0ca3f05..747350654 100755 --- a/runsc/version_test.sh +++ b/runsc/version_test.sh @@ -16,7 +16,7 @@ set -euf -x -o pipefail -readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc" +readonly runsc="$1" readonly version=$($runsc --version) # Version should should not match VERSION, which is the default and which will diff --git a/scripts/common.sh b/scripts/common.sh index fdb1aa142..cd91b9f8e 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -16,11 +16,7 @@ set -xeou pipefail -if [[ -f $(dirname $0)/common_google.sh ]]; then - source $(dirname $0)/common_google.sh -else - source $(dirname $0)/common_bazel.sh -fi +source $(dirname $0)/common_build.sh # Ensure it attempts to collect logs in all cases. trap collect_logs EXIT diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh deleted file mode 100755 index a473a88a4..000000000 --- a/scripts/common_bazel.sh +++ /dev/null @@ -1,99 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Install the latest version of Bazel and log the version. -(which use_bazel.sh && use_bazel.sh latest) || which bazel -bazel version - -# Switch into the workspace; only necessary if run with kokoro. -if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then - cd git/repo -elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then - cd github/repo -fi - -# Set the standard bazel flags. -declare -r BAZEL_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then - declare -r BAZEL_RBE_AUTH_FLAGS=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" - ) - declare -r BAZEL_RBE_FLAGS=("--config=remote") -fi - -# Wrap bazel. -function build() { - bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | - tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' -} - -function test() { - bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" -} - -function run() { - local binary=$1 - shift - bazel run "${binary}" -- "$@" -} - -function run_as_root() { - local binary=$1 - shift - bazel run --run_under="sudo" "${binary}" -- "$@" -} - -function collect_logs() { - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then - # Merge results files of all shards for each test suite. - for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do - junitparser merge `find $d -name test.xml` $d/test.xml - cat $d/shard_*_of_*/test.log > $d/test.log - ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip - done - find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf - # Move test logs to Kokoro directory. tar is used to conveniently perform - # renames while moving files. - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - - # Collect sentry logs, if any. - if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - # Check if the directory is empty or not (only the first line it needed). - local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) - if [[ "${logs}" ]]; then - local -r archive=runsc_logs_"${RUNTIME}".tar.gz - if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then - echo "runsc logs will be uploaded to:" - echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" - echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" - fi - tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . - fi - fi - fi -} - -function find_branch_name() { - git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename -} diff --git a/scripts/common_build.sh b/scripts/common_build.sh new file mode 100755 index 000000000..a473a88a4 --- /dev/null +++ b/scripts/common_build.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install the latest version of Bazel and log the version. +(which use_bazel.sh && use_bazel.sh latest) || which bazel +bazel version + +# Switch into the workspace; only necessary if run with kokoro. +if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then + cd git/repo +elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then + cd github/repo +fi + +# Set the standard bazel flags. +declare -r BAZEL_FLAGS=( + "--show_timestamps" + "--test_output=errors" + "--keep_going" + "--verbose_failures=true" +) +if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then + declare -r BAZEL_RBE_AUTH_FLAGS=( + "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" + ) + declare -r BAZEL_RBE_FLAGS=("--config=remote") +fi + +# Wrap bazel. +function build() { + bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | + tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' +} + +function test() { + bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" +} + +function run() { + local binary=$1 + shift + bazel run "${binary}" -- "$@" +} + +function run_as_root() { + local binary=$1 + shift + bazel run --run_under="sudo" "${binary}" -- "$@" +} + +function collect_logs() { + # Zip out everything into a convenient form. + if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Merge results files of all shards for each test suite. + for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do + junitparser merge `find $d -name test.xml` $d/test.xml + cat $d/shard_*_of_*/test.log > $d/test.log + ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip + done + find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf + # Move test logs to Kokoro directory. tar is used to conveniently perform + # renames while moving files. + find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | + tar --create --files-from - --transform 's/test\./sponge_log./' | + tar --extract --directory ${KOKORO_ARTIFACTS_DIR} + + # Collect sentry logs, if any. + if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then + # Check if the directory is empty or not (only the first line it needed). + local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) + if [[ "${logs}" ]]; then + local -r archive=runsc_logs_"${RUNTIME}".tar.gz + if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then + echo "runsc logs will be uploaded to:" + echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" + echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" + fi + tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . + fi + fi + fi +} + +function find_branch_name() { + git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename +} diff --git a/test/BUILD b/test/BUILD index bf834d994..34b950644 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,44 +1 @@ -package(licenses = ["notice"]) # Apache 2.0 - -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -alias( - name = "rbe_ubuntu1604", - actual = ":rbe_ubuntu1604_r346485", -) - -platform( - name = "rbe_ubuntu1604_r346485", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [ - ], - target_compatible_with = [ - ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) +package(licenses = ["notice"]) diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 4fe03a220..76e04f878 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,7 @@ go_test( "integration_test.go", "regression_test.go", ], - embed = [":integration"], + library = ":integration", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -29,5 +29,4 @@ go_test( go_library( name = "integration", srcs = ["integration.go"], - importpath = "gvisor.dev/gvisor/test/integration", ) diff --git a/test/image/BUILD b/test/image/BUILD index 09b0a0ad5..7392ac54e 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +14,7 @@ go_test( "ruby.rb", "ruby.sh", ], - embed = [":image"], + library = ":image", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -30,5 +30,4 @@ go_test( go_library( name = "image", srcs = ["image.go"], - importpath = "gvisor.dev/gvisor/test/image", ) diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 22f470092..6bb3b82b5 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "iptables_util.go", "nat.go", ], - importpath = "gvisor.dev/gvisor/test/iptables", visibility = ["//test/iptables:__subpackages__"], deps = [ "//runsc/testutil", @@ -24,7 +23,7 @@ go_test( srcs = [ "iptables_test.go", ], - embed = [":iptables"], + library = ":iptables", tags = [ "local", "manual", diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD index a5b6f082c..b9199387a 100644 --- a/test/iptables/runner/BUILD +++ b/test/iptables/runner/BUILD @@ -1,15 +1,21 @@ -load("@io_bazel_rules_docker//go:image.bzl", "go_image") -load("@io_bazel_rules_docker//container:container.bzl", "container_image") +load("//tools:defs.bzl", "container_image", "go_binary", "go_image") package(licenses = ["notice"]) +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + deps = ["//test/iptables"], +) + container_image( name = "iptables-base", base = "@iptables-test//image", ) go_image( - name = "runner", + name = "runner-image", testonly = 1, srcs = ["main.go"], base = ":iptables-base", diff --git a/test/root/BUILD b/test/root/BUILD index d5dd9bca2..23ce2a70f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "root", srcs = ["root.go"], - importpath = "gvisor.dev/gvisor/test/root", ) go_test( @@ -21,7 +20,7 @@ go_test( data = [ "//runsc", ], - embed = [":root"], + library = ":root", tags = [ # Requires docker and runsc to be configured before the test runs. # Also test only runs as root. diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD index 125633680..bca5f9cab 100644 --- a/test/root/testdata/BUILD +++ b/test/root/testdata/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "sandbox.go", "simple.go", ], - importpath = "gvisor.dev/gvisor/test/root/testdata", visibility = [ "//visibility:public", ], diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 367295206..2c472bf8d 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,6 +1,6 @@ # These packages are used to run language runtime tests inside gVisor sandboxes. -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") load("//test/runtimes:build_defs.bzl", "runtime_test") package(licenses = ["notice"]) @@ -49,5 +49,5 @@ go_test( name = "blacklist_test", size = "small", srcs = ["blacklist_test.go"], - embed = [":runner"], + library = ":runner", ) diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 6f84ca852..92e275a76 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -1,6 +1,6 @@ """Defines a rule for runtime test targets.""" -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test", "loopback") def runtime_test( name, @@ -34,6 +34,7 @@ def runtime_test( ] data = [ ":runner", + loopback, ] if blacklist_file: args += ["--blacklist_file", "test/runtimes/" + blacklist_file] @@ -61,7 +62,7 @@ def blacklist_test(name, blacklist_file): """Test that a blacklist parses correctly.""" go_test( name = name + "_blacklist_test", - embed = [":runner"], + library = ":runner", srcs = ["blacklist_test.go"], args = ["--blacklist_file", "test/runtimes/" + blacklist_file], data = [blacklist_file], diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD index 09dc6c42f..85e004c45 100644 --- a/test/runtimes/images/proctor/BUILD +++ b/test/runtimes/images/proctor/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -19,7 +19,7 @@ go_test( name = "proctor_test", size = "small", srcs = ["proctor_test.go"], - embed = [":proctor"], + library = ":proctor", deps = [ "//runsc/testutil", ], diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 90d52e73b..40e974314 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//test/syscalls:build_defs.bzl", "syscall_test") package(licenses = ["notice"]) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index aaf77c65b..1df761dd0 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -1,5 +1,7 @@ """Defines a rule for syscall test targets.""" +load("//tools:defs.bzl", "loopback") + # syscall_test is a macro that will create targets to run the given test target # on the host (native) and runsc. def syscall_test( @@ -135,6 +137,7 @@ def _syscall_test( name = name, data = [ ":syscall_test_runner", + loopback, test, ], args = args, @@ -148,6 +151,3 @@ def sh_test(**kwargs): native.sh_test( **kwargs ) - -def select_for_linux(for_linux, for_others = []): - return for_linux diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD index 9293f25cb..de4b2727c 100644 --- a/test/syscalls/gtest/BUILD +++ b/test/syscalls/gtest/BUILD @@ -1,12 +1,9 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "gtest", srcs = ["gtest.go"], - importpath = "gvisor.dev/gvisor/test/syscalls/gtest", - visibility = [ - "//test:__subpackages__", - ], + visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 4c7ec3f06..c2ef50c1d 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "select_system") package( default_visibility = ["//:sandbox"], @@ -126,13 +125,11 @@ cc_library( testonly = 1, srcs = [ "socket_test_util.cc", - ] + select_for_linux( - [ - "socket_test_util_impl.cc", - ], - ), + "socket_test_util_impl.cc", + ], hdrs = ["socket_test_util.h"], - deps = [ + defines = select_system(), + deps = default_net_util() + [ "@com_google_googletest//:gtest", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -143,8 +140,7 @@ cc_library( "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - ] + select_for_linux([ - ]), + ], ) cc_library( @@ -1443,6 +1439,7 @@ cc_binary( srcs = ["arch_prctl.cc"], linkstatic = 1, deps = [ + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "@com_google_googletest//:gtest", @@ -3383,11 +3380,11 @@ cc_library( name = "udp_socket_test_cases", testonly = 1, srcs = [ - "udp_socket_test_cases.cc", - ] + select_for_linux([ "udp_socket_errqueue_test_case.cc", - ]), + "udp_socket_test_cases.cc", + ], hdrs = ["udp_socket_test_cases.h"], + defines = select_system(), deps = [ ":socket_test_util", ":unix_domain_socket_test_util", diff --git a/test/syscalls/linux/arch_prctl.cc b/test/syscalls/linux/arch_prctl.cc index 81bf5a775..3a901faf5 100644 --- a/test/syscalls/linux/arch_prctl.cc +++ b/test/syscalls/linux/arch_prctl.cc @@ -14,8 +14,10 @@ #include #include +#include #include "gtest/gtest.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" // glibc does not provide a prototype for arch_prctl() so declare it here. diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD index 5cfe4e56f..ed488dbc2 100644 --- a/test/syscalls/linux/rseq/BUILD +++ b/test/syscalls/linux/rseq/BUILD @@ -1,8 +1,7 @@ # This package contains a standalone rseq test binary. This binary must not # depend on libc, which might use rseq itself. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") -load("@rules_cc//cc:defs.bzl", "cc_library") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain") package(licenses = ["notice"]) @@ -37,8 +36,8 @@ genrule( "$(location start.S)", ]), toolchains = [ + cc_toolchain, ":no_pie_cc_flags", - "@bazel_tools//tools/cpp:current_cc_toolchain", ], visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc index 147978f46..9a24e1df0 100644 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include "test/syscalls/linux/udp_socket_test_cases.h" #include @@ -52,3 +54,5 @@ TEST_P(UdpSocketTest, ErrorQueue) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/test/uds/BUILD b/test/uds/BUILD index a3843e699..51e2c7ce8 100644 --- a/test/uds/BUILD +++ b/test/uds/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -9,7 +9,6 @@ go_library( name = "uds", testonly = 1, srcs = ["uds.go"], - importpath = "gvisor.dev/gvisor/test/uds", deps = [ "//pkg/log", "//pkg/unet", diff --git a/test/util/BUILD b/test/util/BUILD index cbc728159..3c732be62 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_library", "cc_test", "select_system") package( default_visibility = ["//:sandbox"], @@ -142,12 +141,13 @@ cc_library( cc_library( name = "save_util", testonly = 1, - srcs = ["save_util.cc"] + - select_for_linux( - ["save_util_linux.cc"], - ["save_util_other.cc"], - ), + srcs = [ + "save_util.cc", + "save_util_linux.cc", + "save_util_other.cc", + ], hdrs = ["save_util.h"], + defines = select_system(), ) cc_library( @@ -234,13 +234,16 @@ cc_library( testonly = 1, srcs = [ "test_util.cc", - ] + select_for_linux( - [ - "test_util_impl.cc", - "test_util_runfiles.cc", + "test_util_impl.cc", + "test_util_runfiles.cc", + ], + hdrs = ["test_util.h"], + defines = select_system( + fuchsia = [ + "__opensource__", + "__fuchsia__", ], ), - hdrs = ["test_util.h"], deps = [ ":fs_util", ":logging", diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index cd56118c0..d0aea8e6a 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifdef __linux__ + #include #include #include @@ -43,3 +45,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 1aca663b7..931af2c29 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __linux__ + namespace gvisor { namespace testing { @@ -21,3 +23,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 7210094eb..694d21692 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include #include @@ -44,3 +46,5 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/tools/BUILD b/tools/BUILD new file mode 100644 index 000000000..e73a9c885 --- /dev/null +++ b/tools/BUILD @@ -0,0 +1,3 @@ +package(licenses = ["notice"]) + +exports_files(["nogo.js"]) diff --git a/tools/build/BUILD b/tools/build/BUILD new file mode 100644 index 000000000..0c0ce3f4d --- /dev/null +++ b/tools/build/BUILD @@ -0,0 +1,10 @@ +package(licenses = ["notice"]) + +# In bazel, no special support is required for loopback networking. This is +# just a dummy data target that does not change the test environment. +genrule( + name = "loopback", + outs = ["loopback.txt"], + cmd = "touch $@", + visibility = ["//visibility:public"], +) diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl new file mode 100644 index 000000000..d0556abd1 --- /dev/null +++ b/tools/build/defs.bzl @@ -0,0 +1,91 @@ +"""Bazel implementations of standard rules.""" + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") +load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") +load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") +load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") +load("@pydeps//:requirements.bzl", _py_requirement = "requirement") + +container_image = _container_image +cc_binary = _cc_binary +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +go_image = _go_image +go_embed_data = _go_embed_data +loopback = "//tools/build:loopback" +proto_library = native.proto_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = native.py_library +py_binary = native.py_binary +py_test = native.py_test + +def go_binary(name, static = False, pure = False, **kwargs): + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_tool_library(name, **kwargs): + _go_tool_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_proto_library(name, proto, **kwargs): + deps = kwargs.pop("deps", []) + _go_proto_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, + proto = proto, + deps = [dep.replace("_proto", "_go_proto") for dep in deps], + **kwargs + ) + +def go_test(name, **kwargs): + library = kwargs.pop("library", None) + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def py_requirement(name, direct = False): + return _py_requirement(name) + +def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): + values = { + "@bazel_tools//src/conditions:linux_x86_64": amd64, + "@bazel_tools//src/conditions:linux_aarch64": arm64, + } + if default: + values["//conditions:default"] = default + return select(values, **kwargs) + +def select_system(linux = ["__linux__"], **kwargs): + return linux # Only Linux supported. + +def default_installer(): + return None + +def default_net_util(): + return [] # Nothing needed. diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index d85c56131..92ba8ab06 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_tool_library") +load("//tools:defs.bzl", "go_tool_library") package(licenses = ["notice"]) go_tool_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - importpath = "checkunsafe", visibility = ["//visibility:public"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", diff --git a/tools/defs.bzl b/tools/defs.bzl new file mode 100644 index 000000000..819f12b0d --- /dev/null +++ b/tools/defs.bzl @@ -0,0 +1,154 @@ +"""Wrappers for common build rules. + +These wrappers apply common BUILD configurations (e.g., proto_library +automagically creating cc_ and go_ proto targets) and act as a single point of +change for Google-internal and bazel-compatible rules. +""" + +load("//tools/go_stateify:defs.bzl", "go_stateify") +load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") +load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") + +# Delegate directly. +cc_binary = _cc_binary +cc_library = _cc_library +cc_test = _cc_test +cc_toolchain = _cc_toolchain +cc_flags_supplier = _cc_flags_supplier +container_image = _container_image +go_embed_data = _go_embed_data +go_image = _go_image +go_test = _go_test +go_tool_library = _go_tool_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = _py_library +py_binary = _py_binary +py_test = _py_test +py_requirement = _py_requirement +select_arch = _select_arch +select_system = _select_system +loopback = _loopback +default_installer = _default_installer +default_net_util = _default_net_util + +def go_binary(name, **kwargs): + """Wraps the standard go_binary. + + Args: + name: the rule name. + **kwargs: standard go_binary arguments. + """ + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): + """Wraps the standard go_library and does stateification and marshalling. + + The recommended way is to use this rule with mostly identical configuration as the native + go_library rule. + + These definitions provide additional flags (stateify, marshal) that can be used + with the generators to automatically supplement the library code. + + load("//tools:defs.bzl", "go_library") + + go_library( + name = "foo", + srcs = ["foo.go"], + ) + + Args: + name: the rule name. + srcs: the library sources. + deps: the library dependencies. + imports: imports required for stateify. + stateify: whether statify is enabled (default: true). + marshal: whether marshal is enabled (default: false). + **kwargs: standard go_library arguments. + """ + if stateify: + # Only do stateification for non-state packages without manual autogen. + go_stateify( + name = name + "_state_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + imports = imports, + package = name, + arch = select_arch(), + out = name + "_state_autogen.go", + ) + all_srcs = srcs + [name + "_state_autogen.go"] + if "//pkg/state" not in deps: + all_deps = deps + ["//pkg/state"] + else: + all_deps = deps + else: + all_deps = deps + all_srcs = srcs + if marshal: + go_marshal( + name = name + "_abi_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + debug = False, + imports = imports, + package = name, + ) + extra_deps = [ + dep + for dep in marshal_deps + if not dep in all_deps + ] + all_deps = all_deps + extra_deps + all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + + if marshal: + # Ignore importpath for go_test. + kwargs.pop("importpath", None) + + _go_test( + name = name + "_abi_autogen_test", + srcs = [name + "_abi_autogen_test.go"], + library = ":" + name, + deps = marshal_test_deps, + **kwargs + ) + +def proto_library(name, srcs, **kwargs): + """Wraps the standard proto_library. + + Given a proto_library named "foo", this produces three different targets: + - foo_proto: proto_library rule. + - foo_go_proto: go_proto_library rule. + - foo_cc_proto: cc_proto_library rule. + + Args: + srcs: the proto sources. + **kwargs: standard proto_library arguments. + """ + deps = kwargs.pop("deps", []) + _proto_library( + name = name + "_proto", + srcs = srcs, + deps = deps, + **kwargs + ) + _go_proto_library( + name = name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + _cc_proto_library( + name = name + "_cc_proto", + deps = [":" + name + "_proto"], + **kwargs + ) diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 39318b877..069df3856 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD index 74853c7d2..38caa3ce7 100644 --- a/tools/go_generics/globals/BUILD +++ b/tools/go_generics/globals/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,6 +8,6 @@ go_library( "globals_visitor.go", "scope.go", ], - importpath = "gvisor.dev/gvisor/tools/go_generics/globals", + stateify = False, visibility = ["//tools/go_generics:__pkg__"], ) diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index 02b09120e..b7d35e272 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD index 9d26a88b7..8a329dfc6 100644 --- a/tools/go_generics/rules_tests/BUILD +++ b/tools/go_generics/rules_tests/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index c862b277c..80d9c0504 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") -package(licenses = ["notice"]) +licenses(["notice"]) go_binary( name = "go_marshal", diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 481575bd3..4886efddf 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -20,19 +20,7 @@ comment `// +marshal`. # Usage -See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`. - -The recommended way to generate a go library with marshalling is to use the -`go_library` with mostly identical configuration as the native go_library rule. - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -``` +See `defs.bzl`: a new rule is provided, `go_marshal`. Under the hood, the `go_marshal` rule is used to generate a file that will appear in a Go target; the output file should appear explicitly in a srcs list. @@ -54,11 +42,7 @@ go_library( "foo.go", "foo_abi.go", ], - deps = [ - "/gvisor/pkg/abi", - "/gvisor/pkg/sentry/safemem/safemem", - "/gvisor/pkg/sentry/usermem/usermem", - ], + ... ) ``` @@ -69,22 +53,6 @@ These tests use reflection to verify properties of the ABI struct, and should be considered part of the generated interfaces (but are too expensive to execute at runtime). Ensure these tests run at some point. -``` -$ cat BUILD -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -$ blaze build :foo -$ blaze query ... -:foo_abi_autogen -:foo_abi_autogen_test -$ blaze test :foo_abi_autogen_test - -``` - # Restrictions Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is @@ -131,22 +99,6 @@ for embedded structs that are not aligned. Because of this, it's generally best to avoid using `marshal:"unaligned"` and insert explicit padding fields instead. -## Debugging go_marshal - -To enable debugging output from the go marshal tool, pass the `-debug` flag to -the tool. When using the build rules from above, add a `debug = True` field to -the build rule like this: - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], - debug = True, -) -``` - ## Modifying the `go_marshal` Tool The following are some guidelines for modifying the `go_marshal` tool: diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD index c859ced77..c2a4d45c4 100644 --- a/tools/go_marshal/analysis/BUILD +++ b/tools/go_marshal/analysis/BUILD @@ -1,12 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "analysis", testonly = 1, srcs = ["analysis_unsafe.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index c32eb559f..2918ceffe 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -1,57 +1,14 @@ -"""Marshal is a tool for generating marshalling interfaces for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_marshal rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_marshal:defs.bzl", "go_marshal") - -go_marshal( - name = "foo_abi", - srcs = ["foo.go"], - out = "foo_abi.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_abi.go", - ], - deps = [ - "//tools/go_marshal:marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") +"""Marshal is a tool for generating marshalling interfaces for Go types.""" def _go_marshal_impl(ctx): """Execute the go_marshal tool.""" output = ctx.outputs.lib output_test = ctx.outputs.test - (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD") - - decl = "/".join(["gvisor.dev/gvisor", build_dir]) # Run the marshal command. args = ["-output=%s" % output.path] args += ["-pkg=%s" % ctx.attr.package] args += ["-output_test=%s" % output_test.path] - args += ["-declarationPkg=%s" % decl] if ctx.attr.debug: args += ["-debug"] @@ -83,7 +40,6 @@ go_marshal = rule( implementation = _go_marshal_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "libname": attr.string(mandatory = True), "imports": attr.string_list(mandatory = False), "package": attr.string(mandatory = True), "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), @@ -95,58 +51,14 @@ go_marshal = rule( }, ) -def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs): - """wraps the standard go_library and does mashalling interface generation. - - Args: - name: Same as native go_library. - srcs: Same as native go_library. - deps: Same as native go_library. - imports: Extra import paths to pass to the go_marshal tool. - debug: Enables debugging output from the go_marshal tool. - **kwargs: Remaining args to pass to the native go_library rule unmodified. - """ - go_marshal( - name = name + "_abi_autogen", - libname = name, - srcs = [src for src in srcs if src.endswith(".go")], - debug = debug, - imports = imports, - package = name, - ) - - extra_deps = [ - "//tools/go_marshal/marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ] - - all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] - all_deps = deps + [] # + extra_deps - - for extra in extra_deps: - if extra not in deps: - all_deps.append(extra) - - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) - - # Don't pass importpath arg to go_test. - kwargs.pop("importpath", "") - - _go_test( - name = name + "_abi_autogen_test", - srcs = [name + "_abi_autogen_test.go"], - # Generated test has a fixed set of dependencies since we generate these - # tests. They should only depend on the library generated above, and the - # Marshallable interface. - deps = [ - ":" + name, - "//tools/go_marshal/analysis", - ], - **kwargs - ) +# marshal_deps are the dependencies requied by generated code. +marshal_deps = [ + "//tools/go_marshal/marshal", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", +] + +# marshal_test_deps are required by test targets. +marshal_test_deps = [ + "//tools/go_marshal/analysis", +] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index a0eae6492..c92b59dd6 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "gomarshal", @@ -10,7 +10,7 @@ go_library( "generator_tests.go", "util.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal", + stateify = False, visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 641ccd938..8392f3f6d 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -62,15 +62,12 @@ type Generator struct { outputTest *os.File // Package name for the generated file. pkg string - // Go import path for package we're processing. This package should directly - // declare the type we're generating code for. - declaration string // Set of extra packages to import in the generated file. imports *importTable } // NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) { +func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) @@ -80,12 +77,11 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) } g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - declaration: declaration, - imports: newImportTable(), + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as @@ -264,7 +260,7 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t, g.declaration) + i := newTestGenerator(t) i.emitTests() return i } @@ -359,7 +355,7 @@ func (g *Generator) Run() error { // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer - b.emit("package %s_test\n\n", g.pkg) + b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index df25cb5b2..bcda17c3b 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -46,7 +46,7 @@ type testGenerator struct { decl *importStmt } -func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { +func newTestGenerator(t *ast.TypeSpec) *testGenerator { if _, ok := t.Type.(*ast.StructType); !ok { panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) } @@ -59,14 +59,12 @@ func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { for _, i := range standardImports { g.imports.add(i).markUsed() } - g.decl = g.imports.add(declaration) - g.decl.markUsed() return g } func (g *testGenerator) typeName() string { - return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name) + return g.t.Name.Name } func (g *testGenerator) forEachField(fn func(f *ast.Field)) { diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go index 3d12eb93c..e1a97b311 100644 --- a/tools/go_marshal/main.go +++ b/tools/go_marshal/main.go @@ -31,11 +31,10 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") - declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on") + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") ) func main() { @@ -62,7 +61,7 @@ func main() { // as an import. extraImports = strings.Split(*imports, ",") } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports) + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) if err != nil { panic(err) } diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index 47dda97a1..ad508c72f 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -1,13 +1,12 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "marshal", srcs = [ "marshal.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index d412e1ccf..38ba49fed 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,7 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) package_group( name = "gomarshal_test", @@ -25,6 +24,6 @@ go_library( name = "test", testonly = 1, srcs = ["test.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test", + marshal = True, deps = ["//tools/go_marshal/test/external"], ) diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD index 9bb89e1da..0cf6da603 100644 --- a/tools/go_marshal/test/external/BUILD +++ b/tools/go_marshal/test/external/BUILD @@ -1,11 +1,11 @@ -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "external", testonly = 1, srcs = ["external.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external", + marshal = True, visibility = ["//tools/go_marshal/test:gomarshal_test"], ) diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index bb53f8ae9..a133d6f8b 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 33267c074..0f261d89f 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -1,41 +1,4 @@ -"""Stateify is a tool for generating state wrappers for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_stateify rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_stateify:defs.bzl", "go_stateify") - -go_stateify( - name = "foo_state", - srcs = ["foo.go"], - out = "foo_state.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_state.go", - ], - deps = [ - "//pkg/state", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library") +"""Stateify is a tool for generating state wrappers for Go types.""" def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" @@ -103,43 +66,3 @@ files and must be added to the srcs of the relevant go_library. "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), }, ) - -def go_library(name, srcs, deps = [], imports = [], **kwargs): - """Standard go_library wrapped which generates state source files. - - Args: - name: the name of the go_library rule. - srcs: sources of the go_library. Each will be processed for stateify - annotations. - deps: dependencies for the go_library. - imports: an optional list of extra non-aliased, Go-style absolute import - paths required for stateified types. - **kwargs: passed to go_library. - """ - if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: - # Only do stateification for non-state packages without manual autogen. - go_stateify( - name = name + "_state_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - imports = imports, - package = name, - arch = select({ - "@bazel_tools//src/conditions:linux_aarch64": "arm64", - "//conditions:default": "amd64", - }), - out = name + "_state_autogen.go", - ) - all_srcs = srcs + [name + "_state_autogen.go"] - if "//pkg/state" not in deps: - all_deps = deps + ["//pkg/state"] - else: - all_deps = deps - else: - all_deps = deps - all_srcs = srcs - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) diff --git a/tools/images/BUILD b/tools/images/BUILD index 2b77c2737..f1699b184 100644 --- a/tools/images/BUILD +++ b/tools/images/BUILD @@ -1,4 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary") load("//tools/images:defs.bzl", "vm_image", "vm_test") package( diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index d8e422a5d..32235813a 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -28,6 +28,8 @@ The vm_test rule can be used to execute a command remotely. For example, ) """ +load("//tools:defs.bzl", "default_installer") + def _vm_image_impl(ctx): script_paths = [] for script in ctx.files.scripts: @@ -165,8 +167,8 @@ def vm_test( targets = kwargs.pop("targets", []) if installer: targets = [installer] + targets - targets = [ - ] + targets + if default_installer(): + targets = [default_installer()] + targets _vm_test( tags = [ "local", diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD index ee7ea11fd..4ef1a3124 100644 --- a/tools/issue_reviver/BUILD +++ b/tools/issue_reviver/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD index 6da22ba1c..da4133472 100644 --- a/tools/issue_reviver/github/BUILD +++ b/tools/issue_reviver/github/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "github", srcs = ["github.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/github", visibility = [ "//tools/issue_reviver:__subpackages__", ], diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD index 2c3675977..d262932bd 100644 --- a/tools/issue_reviver/reviver/BUILD +++ b/tools/issue_reviver/reviver/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "reviver", srcs = ["reviver.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/reviver", visibility = [ "//tools/issue_reviver:__subpackages__", ], @@ -15,5 +14,5 @@ go_test( name = "reviver_test", size = "small", srcs = ["reviver_test.go"], - embed = [":reviver"], + library = ":reviver", ) diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh index fb09ff331..a22c8c9f2 100755 --- a/tools/workspace_status.sh +++ b/tools/workspace_status.sh @@ -15,4 +15,4 @@ # limitations under the License. # The STABLE_ prefix will trigger a re-link if it changes. -echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty) +echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0) diff --git a/vdso/BUILD b/vdso/BUILD index 2b6744c26..d37d4266d 100644 --- a/vdso/BUILD +++ b/vdso/BUILD @@ -3,20 +3,10 @@ # normal system VDSO (time, gettimeofday, clock_gettimeofday) but which uses # timekeeping parameters managed by the sandbox kernel. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch") package(licenses = ["notice"]) -config_setting( - name = "x86_64", - constraint_values = ["@bazel_tools//platforms:x86_64"], -) - -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - genrule( name = "vdso", srcs = [ @@ -39,14 +29,15 @@ genrule( "-O2 " + "-std=c++11 " + "-fPIC " + + "-fno-sanitize=all " + # Some toolchains enable stack protector by default. Disable it, the # VDSO has no hooks to handle failures. "-fno-stack-protector " + "-fuse-ld=gold " + - select({ - ":x86_64": "-m64 ", - "//conditions:default": "", - }) + + select_arch( + amd64 = "-m64 ", + arm64 = "", + ) + "-shared " + "-nostdlib " + "-Wl,-soname=linux-vdso.so.1 " + @@ -55,12 +46,10 @@ genrule( "-Wl,-Bsymbolic " + "-Wl,-z,max-page-size=4096 " + "-Wl,-z,common-page-size=4096 " + - select( - { - ":x86_64": "-Wl,-T$(location vdso_amd64.lds) ", - ":aarch64": "-Wl,-T$(location vdso_arm64.lds) ", - }, - no_match_error = "Unsupported architecture", + select_arch( + amd64 = "-Wl,-T$(location vdso_amd64.lds) ", + arm64 = "-Wl,-T$(location vdso_arm64.lds) ", + no_match_error = "unsupported architecture", ) + "-o $(location vdso.so) " + "$(location vdso.cc) " + @@ -73,7 +62,7 @@ genrule( ], features = ["-pie"], toolchains = [ - "@bazel_tools//tools/cpp:current_cc_toolchain", + cc_toolchain, ":no_pie_cc_flags", ], visibility = ["//:sandbox"], -- cgit v1.2.3 From 0e2f1b7abd219f39d67cc2cecd00c441a13eeb29 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 15:17:58 -0800 Subject: Update package locations. Because the abi will depend on the core types for marshalling (usermem, context, safemem, safecopy), these need to be flattened from the sentry directory. These packages contain no sentry-specific details. PiperOrigin-RevId: 291811289 --- pkg/abi/abi.go | 4 + pkg/context/BUILD | 13 + pkg/context/context.go | 141 +++++ pkg/safecopy/BUILD | 29 + pkg/safecopy/LICENSE | 27 + pkg/safecopy/atomic_amd64.s | 136 +++++ pkg/safecopy/atomic_arm64.s | 126 +++++ pkg/safecopy/memclr_amd64.s | 147 +++++ pkg/safecopy/memclr_arm64.s | 74 +++ pkg/safecopy/memcpy_amd64.s | 250 +++++++++ pkg/safecopy/memcpy_arm64.s | 78 +++ pkg/safecopy/safecopy.go | 144 +++++ pkg/safecopy/safecopy_test.go | 617 +++++++++++++++++++++ pkg/safecopy/safecopy_unsafe.go | 335 +++++++++++ pkg/safecopy/sighandler_amd64.s | 133 +++++ pkg/safecopy/sighandler_arm64.s | 143 +++++ pkg/safemem/BUILD | 27 + pkg/safemem/block_unsafe.go | 279 ++++++++++ pkg/safemem/io.go | 392 +++++++++++++ pkg/safemem/io_test.go | 199 +++++++ pkg/safemem/safemem.go | 16 + pkg/safemem/seq_test.go | 196 +++++++ pkg/safemem/seq_unsafe.go | 299 ++++++++++ pkg/sentry/arch/BUILD | 4 +- pkg/sentry/arch/arch.go | 2 +- pkg/sentry/arch/arch_aarch64.go | 2 +- pkg/sentry/arch/arch_amd64.go | 2 +- pkg/sentry/arch/arch_arm64.go | 2 +- pkg/sentry/arch/arch_state_x86.go | 2 +- pkg/sentry/arch/arch_x86.go | 2 +- pkg/sentry/arch/auxv.go | 2 +- pkg/sentry/arch/signal.go | 2 +- pkg/sentry/arch/signal_amd64.go | 2 +- pkg/sentry/arch/signal_arm64.go | 2 +- pkg/sentry/arch/signal_stack.go | 2 +- pkg/sentry/arch/stack.go | 4 +- pkg/sentry/context/BUILD | 13 - pkg/sentry/context/context.go | 141 ----- pkg/sentry/context/contexttest/BUILD | 21 - pkg/sentry/context/contexttest/contexttest.go | 188 ------- pkg/sentry/contexttest/BUILD | 21 + pkg/sentry/contexttest/contexttest.go | 188 +++++++ pkg/sentry/fs/BUILD | 12 +- pkg/sentry/fs/anon/BUILD | 4 +- pkg/sentry/fs/anon/anon.go | 4 +- pkg/sentry/fs/attr.go | 2 +- pkg/sentry/fs/context.go | 2 +- pkg/sentry/fs/copy_up.go | 4 +- pkg/sentry/fs/copy_up_test.go | 2 +- pkg/sentry/fs/dev/BUILD | 6 +- pkg/sentry/fs/dev/dev.go | 4 +- pkg/sentry/fs/dev/fs.go | 2 +- pkg/sentry/fs/dev/full.go | 4 +- pkg/sentry/fs/dev/null.go | 2 +- pkg/sentry/fs/dev/random.go | 6 +- pkg/sentry/fs/dev/tty.go | 2 +- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/dirent_refs_test.go | 4 +- pkg/sentry/fs/fdpipe/BUILD | 12 +- pkg/sentry/fs/fdpipe/pipe.go | 6 +- pkg/sentry/fs/fdpipe/pipe_opener.go | 2 +- pkg/sentry/fs/fdpipe/pipe_opener_test.go | 6 +- pkg/sentry/fs/fdpipe/pipe_state.go | 2 +- pkg/sentry/fs/fdpipe/pipe_test.go | 4 +- pkg/sentry/fs/file.go | 4 +- pkg/sentry/fs/file_operations.go | 4 +- pkg/sentry/fs/file_overlay.go | 4 +- pkg/sentry/fs/file_overlay_test.go | 2 +- pkg/sentry/fs/filesystems.go | 2 +- pkg/sentry/fs/filetest/BUILD | 6 +- pkg/sentry/fs/filetest/filetest.go | 6 +- pkg/sentry/fs/fs.go | 2 +- pkg/sentry/fs/fsutil/BUILD | 14 +- pkg/sentry/fs/fsutil/dirty_set.go | 6 +- pkg/sentry/fs/fsutil/dirty_set_test.go | 2 +- pkg/sentry/fs/fsutil/file.go | 4 +- pkg/sentry/fs/fsutil/file_range_set.go | 6 +- pkg/sentry/fs/fsutil/host_file_mapper.go | 4 +- pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go | 2 +- pkg/sentry/fs/fsutil/host_mappable.go | 6 +- pkg/sentry/fs/fsutil/inode.go | 2 +- pkg/sentry/fs/fsutil/inode_cached.go | 6 +- pkg/sentry/fs/fsutil/inode_cached_test.go | 8 +- pkg/sentry/fs/gofer/BUILD | 10 +- pkg/sentry/fs/gofer/attr.go | 4 +- pkg/sentry/fs/gofer/cache_policy.go | 2 +- pkg/sentry/fs/gofer/context_file.go | 2 +- pkg/sentry/fs/gofer/file.go | 4 +- pkg/sentry/fs/gofer/file_state.go | 2 +- pkg/sentry/fs/gofer/fs.go | 2 +- pkg/sentry/fs/gofer/gofer_test.go | 4 +- pkg/sentry/fs/gofer/handles.go | 4 +- pkg/sentry/fs/gofer/inode.go | 4 +- pkg/sentry/fs/gofer/inode_state.go | 2 +- pkg/sentry/fs/gofer/path.go | 2 +- pkg/sentry/fs/gofer/session.go | 2 +- pkg/sentry/fs/gofer/session_state.go | 2 +- pkg/sentry/fs/gofer/socket.go | 2 +- pkg/sentry/fs/gofer/util.go | 2 +- pkg/sentry/fs/host/BUILD | 12 +- pkg/sentry/fs/host/control.go | 2 +- pkg/sentry/fs/host/file.go | 6 +- pkg/sentry/fs/host/fs.go | 2 +- pkg/sentry/fs/host/fs_test.go | 4 +- pkg/sentry/fs/host/inode.go | 4 +- pkg/sentry/fs/host/inode_state.go | 2 +- pkg/sentry/fs/host/inode_test.go | 2 +- pkg/sentry/fs/host/socket.go | 2 +- pkg/sentry/fs/host/socket_test.go | 4 +- pkg/sentry/fs/host/tty.go | 4 +- pkg/sentry/fs/host/wait_test.go | 2 +- pkg/sentry/fs/inode.go | 2 +- pkg/sentry/fs/inode_operations.go | 2 +- pkg/sentry/fs/inode_overlay.go | 2 +- pkg/sentry/fs/inode_overlay_test.go | 2 +- pkg/sentry/fs/inotify.go | 4 +- pkg/sentry/fs/inotify_event.go | 4 +- pkg/sentry/fs/mock.go | 2 +- pkg/sentry/fs/mount.go | 2 +- pkg/sentry/fs/mount_overlay.go | 2 +- pkg/sentry/fs/mount_test.go | 2 +- pkg/sentry/fs/mounts.go | 2 +- pkg/sentry/fs/mounts_test.go | 2 +- pkg/sentry/fs/offset.go | 2 +- pkg/sentry/fs/overlay.go | 4 +- pkg/sentry/fs/proc/BUILD | 8 +- pkg/sentry/fs/proc/cgroup.go | 2 +- pkg/sentry/fs/proc/cpuinfo.go | 2 +- pkg/sentry/fs/proc/exec_args.go | 4 +- pkg/sentry/fs/proc/fds.go | 2 +- pkg/sentry/fs/proc/filesystems.go | 2 +- pkg/sentry/fs/proc/fs.go | 2 +- pkg/sentry/fs/proc/inode.go | 4 +- pkg/sentry/fs/proc/loadavg.go | 2 +- pkg/sentry/fs/proc/meminfo.go | 4 +- pkg/sentry/fs/proc/mounts.go | 2 +- pkg/sentry/fs/proc/net.go | 4 +- pkg/sentry/fs/proc/proc.go | 2 +- pkg/sentry/fs/proc/seqfile/BUILD | 10 +- pkg/sentry/fs/proc/seqfile/seqfile.go | 4 +- pkg/sentry/fs/proc/seqfile/seqfile_test.go | 6 +- pkg/sentry/fs/proc/stat.go | 2 +- pkg/sentry/fs/proc/sys.go | 4 +- pkg/sentry/fs/proc/sys_net.go | 4 +- pkg/sentry/fs/proc/sys_net_test.go | 4 +- pkg/sentry/fs/proc/task.go | 4 +- pkg/sentry/fs/proc/uid_gid_map.go | 4 +- pkg/sentry/fs/proc/uptime.go | 4 +- pkg/sentry/fs/proc/version.go | 2 +- pkg/sentry/fs/ramfs/BUILD | 6 +- pkg/sentry/fs/ramfs/dir.go | 2 +- pkg/sentry/fs/ramfs/socket.go | 2 +- pkg/sentry/fs/ramfs/symlink.go | 2 +- pkg/sentry/fs/ramfs/tree.go | 4 +- pkg/sentry/fs/ramfs/tree_test.go | 2 +- pkg/sentry/fs/splice.go | 2 +- pkg/sentry/fs/sys/BUILD | 4 +- pkg/sentry/fs/sys/devices.go | 2 +- pkg/sentry/fs/sys/fs.go | 2 +- pkg/sentry/fs/sys/sys.go | 4 +- pkg/sentry/fs/timerfd/BUILD | 4 +- pkg/sentry/fs/timerfd/timerfd.go | 4 +- pkg/sentry/fs/tmpfs/BUILD | 10 +- pkg/sentry/fs/tmpfs/file_regular.go | 4 +- pkg/sentry/fs/tmpfs/file_test.go | 4 +- pkg/sentry/fs/tmpfs/fs.go | 2 +- pkg/sentry/fs/tmpfs/inode_file.go | 6 +- pkg/sentry/fs/tmpfs/tmpfs.go | 4 +- pkg/sentry/fs/tty/BUILD | 10 +- pkg/sentry/fs/tty/dir.go | 4 +- pkg/sentry/fs/tty/fs.go | 2 +- pkg/sentry/fs/tty/line_discipline.go | 4 +- pkg/sentry/fs/tty/master.go | 4 +- pkg/sentry/fs/tty/queue.go | 6 +- pkg/sentry/fs/tty/slave.go | 4 +- pkg/sentry/fs/tty/terminal.go | 4 +- pkg/sentry/fs/tty/tty_test.go | 4 +- pkg/sentry/fsimpl/ext/BUILD | 12 +- pkg/sentry/fsimpl/ext/benchmark/BUILD | 4 +- pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 4 +- pkg/sentry/fsimpl/ext/directory.go | 2 +- pkg/sentry/fsimpl/ext/ext.go | 2 +- pkg/sentry/fsimpl/ext/ext_test.go | 6 +- pkg/sentry/fsimpl/ext/file_description.go | 2 +- pkg/sentry/fsimpl/ext/filesystem.go | 2 +- pkg/sentry/fsimpl/ext/regular_file.go | 6 +- pkg/sentry/fsimpl/ext/symlink.go | 4 +- pkg/sentry/fsimpl/kernfs/BUILD | 10 +- pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 4 +- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 4 +- pkg/sentry/fsimpl/kernfs/filesystem.go | 2 +- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 6 +- pkg/sentry/fsimpl/kernfs/symlink.go | 2 +- pkg/sentry/fsimpl/proc/BUILD | 12 +- pkg/sentry/fsimpl/proc/filesystem.go | 2 +- pkg/sentry/fsimpl/proc/subtasks.go | 2 +- pkg/sentry/fsimpl/proc/task.go | 2 +- pkg/sentry/fsimpl/proc/task_files.go | 6 +- pkg/sentry/fsimpl/proc/tasks.go | 2 +- pkg/sentry/fsimpl/proc/tasks_files.go | 4 +- pkg/sentry/fsimpl/proc/tasks_net.go | 4 +- pkg/sentry/fsimpl/proc/tasks_sys.go | 2 +- pkg/sentry/fsimpl/proc/tasks_sys_test.go | 2 +- pkg/sentry/fsimpl/proc/tasks_test.go | 4 +- pkg/sentry/fsimpl/sys/BUILD | 2 +- pkg/sentry/fsimpl/sys/sys.go | 2 +- pkg/sentry/fsimpl/testutil/BUILD | 4 +- pkg/sentry/fsimpl/testutil/kernel.go | 2 +- pkg/sentry/fsimpl/testutil/testutil.go | 4 +- pkg/sentry/fsimpl/tmpfs/BUILD | 16 +- pkg/sentry/fsimpl/tmpfs/benchmark_test.go | 4 +- pkg/sentry/fsimpl/tmpfs/directory.go | 2 +- pkg/sentry/fsimpl/tmpfs/filesystem.go | 2 +- pkg/sentry/fsimpl/tmpfs/named_pipe.go | 4 +- pkg/sentry/fsimpl/tmpfs/pipe_test.go | 6 +- pkg/sentry/fsimpl/tmpfs/regular_file.go | 6 +- pkg/sentry/fsimpl/tmpfs/regular_file_test.go | 4 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/hostmm/BUILD | 2 +- pkg/sentry/hostmm/hostmm.go | 2 +- pkg/sentry/inet/BUILD | 2 +- pkg/sentry/inet/context.go | 2 +- pkg/sentry/kernel/BUILD | 12 +- pkg/sentry/kernel/auth/BUILD | 2 +- pkg/sentry/kernel/auth/context.go | 2 +- pkg/sentry/kernel/auth/id_map.go | 2 +- pkg/sentry/kernel/context.go | 2 +- pkg/sentry/kernel/contexttest/BUILD | 4 +- pkg/sentry/kernel/contexttest/contexttest.go | 4 +- pkg/sentry/kernel/epoll/BUILD | 6 +- pkg/sentry/kernel/epoll/epoll.go | 4 +- pkg/sentry/kernel/epoll/epoll_test.go | 2 +- pkg/sentry/kernel/eventfd/BUILD | 8 +- pkg/sentry/kernel/eventfd/eventfd.go | 4 +- pkg/sentry/kernel/eventfd/eventfd_test.go | 4 +- pkg/sentry/kernel/fd_table.go | 2 +- pkg/sentry/kernel/fd_table_test.go | 4 +- pkg/sentry/kernel/futex/BUILD | 6 +- pkg/sentry/kernel/futex/futex.go | 2 +- pkg/sentry/kernel/futex/futex_test.go | 2 +- pkg/sentry/kernel/kernel.go | 2 +- pkg/sentry/kernel/pipe/BUILD | 12 +- pkg/sentry/kernel/pipe/buffer.go | 2 +- pkg/sentry/kernel/pipe/buffer_test.go | 2 +- pkg/sentry/kernel/pipe/node.go | 2 +- pkg/sentry/kernel/pipe/node_test.go | 6 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_test.go | 4 +- pkg/sentry/kernel/pipe/pipe_util.go | 4 +- pkg/sentry/kernel/pipe/reader_writer.go | 4 +- pkg/sentry/kernel/pipe/vfs.go | 4 +- pkg/sentry/kernel/ptrace.go | 2 +- pkg/sentry/kernel/ptrace_amd64.go | 2 +- pkg/sentry/kernel/ptrace_arm64.go | 2 +- pkg/sentry/kernel/rseq.go | 2 +- pkg/sentry/kernel/seccomp.go | 2 +- pkg/sentry/kernel/semaphore/BUILD | 6 +- pkg/sentry/kernel/semaphore/semaphore.go | 2 +- pkg/sentry/kernel/semaphore/semaphore_test.go | 4 +- pkg/sentry/kernel/shm/BUILD | 4 +- pkg/sentry/kernel/shm/shm.go | 4 +- pkg/sentry/kernel/signalfd/BUILD | 4 +- pkg/sentry/kernel/signalfd/signalfd.go | 4 +- pkg/sentry/kernel/syscalls.go | 2 +- pkg/sentry/kernel/task.go | 4 +- pkg/sentry/kernel/task_clone.go | 2 +- pkg/sentry/kernel/task_context.go | 4 +- pkg/sentry/kernel/task_futex.go | 2 +- pkg/sentry/kernel/task_log.go | 2 +- pkg/sentry/kernel/task_run.go | 2 +- pkg/sentry/kernel/task_signals.go | 2 +- pkg/sentry/kernel/task_start.go | 2 +- pkg/sentry/kernel/task_syscall.go | 2 +- pkg/sentry/kernel/task_usermem.go | 2 +- pkg/sentry/kernel/time/BUILD | 2 +- pkg/sentry/kernel/time/context.go | 2 +- pkg/sentry/kernel/timekeeper_test.go | 4 +- pkg/sentry/kernel/vdso.go | 4 +- pkg/sentry/limits/BUILD | 2 +- pkg/sentry/limits/context.go | 2 +- pkg/sentry/loader/BUILD | 6 +- pkg/sentry/loader/elf.go | 4 +- pkg/sentry/loader/interpreter.go | 4 +- pkg/sentry/loader/loader.go | 4 +- pkg/sentry/loader/vdso.go | 6 +- pkg/sentry/memmap/BUILD | 6 +- pkg/sentry/memmap/mapping_set.go | 2 +- pkg/sentry/memmap/mapping_set_test.go | 2 +- pkg/sentry/memmap/memmap.go | 4 +- pkg/sentry/mm/BUILD | 18 +- pkg/sentry/mm/address_space.go | 2 +- pkg/sentry/mm/aio_context.go | 4 +- pkg/sentry/mm/debug.go | 2 +- pkg/sentry/mm/io.go | 6 +- pkg/sentry/mm/lifecycle.go | 4 +- pkg/sentry/mm/metadata.go | 2 +- pkg/sentry/mm/mm.go | 4 +- pkg/sentry/mm/mm_test.go | 6 +- pkg/sentry/mm/pma.go | 8 +- pkg/sentry/mm/procfs.go | 4 +- pkg/sentry/mm/save_restore.go | 2 +- pkg/sentry/mm/shm.go | 4 +- pkg/sentry/mm/special_mappable.go | 4 +- pkg/sentry/mm/syscalls.go | 4 +- pkg/sentry/mm/vma.go | 4 +- pkg/sentry/pgalloc/BUILD | 8 +- pkg/sentry/pgalloc/context.go | 2 +- pkg/sentry/pgalloc/pgalloc.go | 6 +- pkg/sentry/pgalloc/pgalloc_test.go | 2 +- pkg/sentry/pgalloc/save_restore.go | 2 +- pkg/sentry/platform/BUILD | 8 +- pkg/sentry/platform/context.go | 2 +- pkg/sentry/platform/kvm/BUILD | 6 +- pkg/sentry/platform/kvm/address_space.go | 2 +- pkg/sentry/platform/kvm/bluepill.go | 2 +- pkg/sentry/platform/kvm/bluepill_fault.go | 2 +- pkg/sentry/platform/kvm/context.go | 2 +- pkg/sentry/platform/kvm/kvm.go | 2 +- pkg/sentry/platform/kvm/kvm_test.go | 2 +- pkg/sentry/platform/kvm/machine.go | 2 +- pkg/sentry/platform/kvm/machine_amd64.go | 2 +- pkg/sentry/platform/kvm/machine_arm64.go | 2 +- pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 2 +- pkg/sentry/platform/kvm/physical_map.go | 2 +- pkg/sentry/platform/kvm/virtual_map.go | 2 +- pkg/sentry/platform/kvm/virtual_map_test.go | 2 +- pkg/sentry/platform/mmap_min_addr.go | 2 +- pkg/sentry/platform/platform.go | 4 +- pkg/sentry/platform/ptrace/BUILD | 4 +- pkg/sentry/platform/ptrace/ptrace.go | 2 +- pkg/sentry/platform/ptrace/ptrace_unsafe.go | 2 +- pkg/sentry/platform/ptrace/stub_unsafe.go | 4 +- pkg/sentry/platform/ptrace/subprocess.go | 2 +- pkg/sentry/platform/ring0/BUILD | 2 +- pkg/sentry/platform/ring0/defs_amd64.go | 2 +- pkg/sentry/platform/ring0/defs_arm64.go | 2 +- pkg/sentry/platform/ring0/gen_offsets/BUILD | 2 +- pkg/sentry/platform/ring0/pagetables/BUILD | 4 +- .../platform/ring0/pagetables/allocator_unsafe.go | 2 +- pkg/sentry/platform/ring0/pagetables/pagetables.go | 2 +- .../ring0/pagetables/pagetables_aarch64.go | 2 +- .../ring0/pagetables/pagetables_amd64_test.go | 2 +- .../ring0/pagetables/pagetables_arm64_test.go | 2 +- .../platform/ring0/pagetables/pagetables_test.go | 2 +- .../platform/ring0/pagetables/pagetables_x86.go | 2 +- pkg/sentry/platform/safecopy/BUILD | 29 - pkg/sentry/platform/safecopy/LICENSE | 27 - pkg/sentry/platform/safecopy/atomic_amd64.s | 136 ----- pkg/sentry/platform/safecopy/atomic_arm64.s | 126 ----- pkg/sentry/platform/safecopy/memclr_amd64.s | 147 ----- pkg/sentry/platform/safecopy/memclr_arm64.s | 74 --- pkg/sentry/platform/safecopy/memcpy_amd64.s | 250 --------- pkg/sentry/platform/safecopy/memcpy_arm64.s | 78 --- pkg/sentry/platform/safecopy/safecopy.go | 144 ----- pkg/sentry/platform/safecopy/safecopy_test.go | 617 --------------------- pkg/sentry/platform/safecopy/safecopy_unsafe.go | 335 ----------- pkg/sentry/platform/safecopy/sighandler_amd64.s | 133 ----- pkg/sentry/platform/safecopy/sighandler_arm64.s | 143 ----- pkg/sentry/safemem/BUILD | 27 - pkg/sentry/safemem/block_unsafe.go | 279 ---------- pkg/sentry/safemem/io.go | 392 ------------- pkg/sentry/safemem/io_test.go | 199 ------- pkg/sentry/safemem/safemem.go | 16 - pkg/sentry/safemem/seq_test.go | 196 ------- pkg/sentry/safemem/seq_unsafe.go | 299 ---------- pkg/sentry/socket/BUILD | 4 +- pkg/sentry/socket/control/BUILD | 4 +- pkg/sentry/socket/control/control.go | 4 +- pkg/sentry/socket/hostinet/BUILD | 6 +- pkg/sentry/socket/hostinet/socket.go | 6 +- pkg/sentry/socket/hostinet/socket_unsafe.go | 4 +- pkg/sentry/socket/hostinet/stack.go | 4 +- pkg/sentry/socket/netfilter/BUILD | 2 +- pkg/sentry/socket/netfilter/netfilter.go | 2 +- pkg/sentry/socket/netlink/BUILD | 4 +- pkg/sentry/socket/netlink/message.go | 2 +- pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/route/BUILD | 2 +- pkg/sentry/socket/netlink/route/protocol.go | 2 +- pkg/sentry/socket/netlink/socket.go | 4 +- pkg/sentry/socket/netlink/uevent/BUILD | 2 +- pkg/sentry/socket/netlink/uevent/protocol.go | 2 +- pkg/sentry/socket/netstack/BUILD | 6 +- pkg/sentry/socket/netstack/netstack.go | 6 +- pkg/sentry/socket/netstack/provider.go | 2 +- pkg/sentry/socket/socket.go | 4 +- pkg/sentry/socket/unix/BUILD | 6 +- pkg/sentry/socket/unix/io.go | 4 +- pkg/sentry/socket/unix/transport/BUILD | 2 +- pkg/sentry/socket/unix/transport/connectioned.go | 2 +- pkg/sentry/socket/unix/transport/connectionless.go | 2 +- pkg/sentry/socket/unix/transport/unix.go | 2 +- pkg/sentry/socket/unix/unix.go | 4 +- pkg/sentry/strace/BUILD | 2 +- pkg/sentry/strace/poll.go | 2 +- pkg/sentry/strace/select.go | 2 +- pkg/sentry/strace/signal.go | 2 +- pkg/sentry/strace/socket.go | 2 +- pkg/sentry/strace/strace.go | 2 +- pkg/sentry/syscalls/linux/BUILD | 6 +- pkg/sentry/syscalls/linux/linux64_amd64.go | 2 +- pkg/sentry/syscalls/linux/linux64_arm64.go | 2 +- pkg/sentry/syscalls/linux/sigset.go | 2 +- pkg/sentry/syscalls/linux/sys_aio.go | 2 +- pkg/sentry/syscalls/linux/sys_epoll.go | 2 +- pkg/sentry/syscalls/linux/sys_file.go | 4 +- pkg/sentry/syscalls/linux/sys_futex.go | 2 +- pkg/sentry/syscalls/linux/sys_getdents.go | 2 +- pkg/sentry/syscalls/linux/sys_mempolicy.go | 2 +- pkg/sentry/syscalls/linux/sys_mmap.go | 2 +- pkg/sentry/syscalls/linux/sys_mount.go | 2 +- pkg/sentry/syscalls/linux/sys_pipe.go | 2 +- pkg/sentry/syscalls/linux/sys_poll.go | 2 +- pkg/sentry/syscalls/linux/sys_random.go | 4 +- pkg/sentry/syscalls/linux/sys_read.go | 2 +- pkg/sentry/syscalls/linux/sys_rlimit.go | 2 +- pkg/sentry/syscalls/linux/sys_seccomp.go | 2 +- pkg/sentry/syscalls/linux/sys_sem.go | 2 +- pkg/sentry/syscalls/linux/sys_signal.go | 2 +- pkg/sentry/syscalls/linux/sys_socket.go | 2 +- pkg/sentry/syscalls/linux/sys_stat.go | 2 +- pkg/sentry/syscalls/linux/sys_stat_amd64.go | 2 +- pkg/sentry/syscalls/linux/sys_stat_arm64.go | 2 +- pkg/sentry/syscalls/linux/sys_thread.go | 2 +- pkg/sentry/syscalls/linux/sys_time.go | 2 +- pkg/sentry/syscalls/linux/sys_timer.go | 2 +- pkg/sentry/syscalls/linux/sys_write.go | 2 +- pkg/sentry/syscalls/linux/sys_xattr.go | 2 +- pkg/sentry/syscalls/linux/timespec.go | 2 +- pkg/sentry/unimpl/BUILD | 2 +- pkg/sentry/unimpl/events.go | 2 +- pkg/sentry/uniqueid/BUILD | 2 +- pkg/sentry/uniqueid/context.go | 2 +- pkg/sentry/usermem/BUILD | 55 -- pkg/sentry/usermem/README.md | 31 -- pkg/sentry/usermem/access_type.go | 128 ----- pkg/sentry/usermem/addr.go | 108 ---- pkg/sentry/usermem/addr_range_seq_test.go | 197 ------- pkg/sentry/usermem/addr_range_seq_unsafe.go | 277 --------- pkg/sentry/usermem/bytes_io.go | 141 ----- pkg/sentry/usermem/bytes_io_unsafe.go | 47 -- pkg/sentry/usermem/usermem.go | 597 -------------------- pkg/sentry/usermem/usermem_arm64.go | 53 -- pkg/sentry/usermem/usermem_test.go | 424 -------------- pkg/sentry/usermem/usermem_unsafe.go | 27 - pkg/sentry/usermem/usermem_x86.go | 38 -- pkg/sentry/vfs/BUILD | 10 +- pkg/sentry/vfs/context.go | 2 +- pkg/sentry/vfs/device.go | 2 +- pkg/sentry/vfs/file_description.go | 4 +- pkg/sentry/vfs/file_description_impl_util.go | 4 +- pkg/sentry/vfs/file_description_impl_util_test.go | 6 +- pkg/sentry/vfs/filesystem.go | 2 +- pkg/sentry/vfs/filesystem_type.go | 2 +- pkg/sentry/vfs/mount.go | 2 +- pkg/sentry/vfs/pathname.go | 2 +- pkg/sentry/vfs/testutil.go | 2 +- pkg/sentry/vfs/vfs.go | 2 +- pkg/usermem/BUILD | 55 ++ pkg/usermem/README.md | 31 ++ pkg/usermem/access_type.go | 128 +++++ pkg/usermem/addr.go | 108 ++++ pkg/usermem/addr_range_seq_test.go | 197 +++++++ pkg/usermem/addr_range_seq_unsafe.go | 277 +++++++++ pkg/usermem/bytes_io.go | 141 +++++ pkg/usermem/bytes_io_unsafe.go | 47 ++ pkg/usermem/usermem.go | 597 ++++++++++++++++++++ pkg/usermem/usermem_arm64.go | 53 ++ pkg/usermem/usermem_test.go | 424 ++++++++++++++ pkg/usermem/usermem_unsafe.go | 27 + pkg/usermem/usermem_x86.go | 38 ++ runsc/boot/BUILD | 6 +- runsc/boot/fds.go | 2 +- runsc/boot/fs.go | 2 +- runsc/boot/loader_test.go | 2 +- runsc/boot/user.go | 4 +- runsc/boot/user_test.go | 2 +- tools/go_marshal/defs.bzl | 4 +- tools/go_marshal/gomarshal/generator.go | 4 +- tools/go_marshal/test/BUILD | 2 +- tools/go_marshal/test/benchmark_test.go | 2 +- 483 files changed, 6839 insertions(+), 6835 deletions(-) create mode 100644 pkg/context/BUILD create mode 100644 pkg/context/context.go create mode 100644 pkg/safecopy/BUILD create mode 100644 pkg/safecopy/LICENSE create mode 100644 pkg/safecopy/atomic_amd64.s create mode 100644 pkg/safecopy/atomic_arm64.s create mode 100644 pkg/safecopy/memclr_amd64.s create mode 100644 pkg/safecopy/memclr_arm64.s create mode 100644 pkg/safecopy/memcpy_amd64.s create mode 100644 pkg/safecopy/memcpy_arm64.s create mode 100644 pkg/safecopy/safecopy.go create mode 100644 pkg/safecopy/safecopy_test.go create mode 100644 pkg/safecopy/safecopy_unsafe.go create mode 100644 pkg/safecopy/sighandler_amd64.s create mode 100644 pkg/safecopy/sighandler_arm64.s create mode 100644 pkg/safemem/BUILD create mode 100644 pkg/safemem/block_unsafe.go create mode 100644 pkg/safemem/io.go create mode 100644 pkg/safemem/io_test.go create mode 100644 pkg/safemem/safemem.go create mode 100644 pkg/safemem/seq_test.go create mode 100644 pkg/safemem/seq_unsafe.go delete mode 100644 pkg/sentry/context/BUILD delete mode 100644 pkg/sentry/context/context.go delete mode 100644 pkg/sentry/context/contexttest/BUILD delete mode 100644 pkg/sentry/context/contexttest/contexttest.go create mode 100644 pkg/sentry/contexttest/BUILD create mode 100644 pkg/sentry/contexttest/contexttest.go delete mode 100644 pkg/sentry/platform/safecopy/BUILD delete mode 100644 pkg/sentry/platform/safecopy/LICENSE delete mode 100644 pkg/sentry/platform/safecopy/atomic_amd64.s delete mode 100644 pkg/sentry/platform/safecopy/atomic_arm64.s delete mode 100644 pkg/sentry/platform/safecopy/memclr_amd64.s delete mode 100644 pkg/sentry/platform/safecopy/memclr_arm64.s delete mode 100644 pkg/sentry/platform/safecopy/memcpy_amd64.s delete mode 100644 pkg/sentry/platform/safecopy/memcpy_arm64.s delete mode 100644 pkg/sentry/platform/safecopy/safecopy.go delete mode 100644 pkg/sentry/platform/safecopy/safecopy_test.go delete mode 100644 pkg/sentry/platform/safecopy/safecopy_unsafe.go delete mode 100644 pkg/sentry/platform/safecopy/sighandler_amd64.s delete mode 100644 pkg/sentry/platform/safecopy/sighandler_arm64.s delete mode 100644 pkg/sentry/safemem/BUILD delete mode 100644 pkg/sentry/safemem/block_unsafe.go delete mode 100644 pkg/sentry/safemem/io.go delete mode 100644 pkg/sentry/safemem/io_test.go delete mode 100644 pkg/sentry/safemem/safemem.go delete mode 100644 pkg/sentry/safemem/seq_test.go delete mode 100644 pkg/sentry/safemem/seq_unsafe.go delete mode 100644 pkg/sentry/usermem/BUILD delete mode 100644 pkg/sentry/usermem/README.md delete mode 100644 pkg/sentry/usermem/access_type.go delete mode 100644 pkg/sentry/usermem/addr.go delete mode 100644 pkg/sentry/usermem/addr_range_seq_test.go delete mode 100644 pkg/sentry/usermem/addr_range_seq_unsafe.go delete mode 100644 pkg/sentry/usermem/bytes_io.go delete mode 100644 pkg/sentry/usermem/bytes_io_unsafe.go delete mode 100644 pkg/sentry/usermem/usermem.go delete mode 100644 pkg/sentry/usermem/usermem_arm64.go delete mode 100644 pkg/sentry/usermem/usermem_test.go delete mode 100644 pkg/sentry/usermem/usermem_unsafe.go delete mode 100644 pkg/sentry/usermem/usermem_x86.go create mode 100644 pkg/usermem/BUILD create mode 100644 pkg/usermem/README.md create mode 100644 pkg/usermem/access_type.go create mode 100644 pkg/usermem/addr.go create mode 100644 pkg/usermem/addr_range_seq_test.go create mode 100644 pkg/usermem/addr_range_seq_unsafe.go create mode 100644 pkg/usermem/bytes_io.go create mode 100644 pkg/usermem/bytes_io_unsafe.go create mode 100644 pkg/usermem/usermem.go create mode 100644 pkg/usermem/usermem_arm64.go create mode 100644 pkg/usermem/usermem_test.go create mode 100644 pkg/usermem/usermem_unsafe.go create mode 100644 pkg/usermem/usermem_x86.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/abi.go b/pkg/abi/abi.go index d56c481c9..e6be93c3a 100644 --- a/pkg/abi/abi.go +++ b/pkg/abi/abi.go @@ -39,3 +39,7 @@ func (o OS) String() string { return fmt.Sprintf("OS(%d)", o) } } + +// ABI is an interface that defines OS-specific interactions. +type ABI interface { +} diff --git a/pkg/context/BUILD b/pkg/context/BUILD new file mode 100644 index 000000000..239f31149 --- /dev/null +++ b/pkg/context/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "context", + srcs = ["context.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/amutex", + "//pkg/log", + ], +) diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 000000000..23e009ef3 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,141 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context defines an internal context type. +// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. +package context + +import ( + "context" + "time" + + "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/log" +) + +type contextID int + +// Globally accessible values from a context. These keys are defined in the +// context package to resolve dependency cycles by not requiring the caller to +// import packages usually required to get these information. +const ( + // CtxThreadGroupID is the current thread group ID when a context represents + // a task context. The value is represented as an int32. + CtxThreadGroupID contextID = iota +) + +// ThreadGroupIDFromContext returns the current thread group ID when ctx +// represents a task context. +func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { + if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { + return tgid.(int32), true + } + return 0, false +} + +// A Context represents a thread of execution (hereafter "goroutine" to reflect +// Go idiosyncrasy). It carries state associated with the goroutine across API +// boundaries. +// +// While Context exists for essentially the same reasons as Go's standard +// context.Context, the standard type represents the state of an operation +// rather than that of a goroutine. This is a critical distinction: +// +// - Unlike context.Context, which "may be passed to functions running in +// different goroutines", it is *not safe* to use the same Context in multiple +// concurrent goroutines. +// +// - It is *not safe* to retain a Context passed to a function beyond the scope +// of that function call. +// +// In both cases, values extracted from the Context should be used instead. +type Context interface { + log.Logger + amutex.Sleeper + context.Context + + // UninterruptibleSleepStart indicates the beginning of an uninterruptible + // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate + // is true and the Context represents a Task, the Task's AddressSpace is + // deactivated. + UninterruptibleSleepStart(deactivate bool) + + // UninterruptibleSleepFinish indicates the end of an uninterruptible sleep + // state that was begun by a previous call to UninterruptibleSleepStart. If + // activate is true and the Context represents a Task, the Task's + // AddressSpace is activated. Normally activate is the same value as the + // deactivate parameter passed to UninterruptibleSleepStart. + UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} + +// Err returns nil. +func (NoopSleeper) Err() error { + return nil +} + +// logContext implements basic logging. +type logContext struct { + log.Logger + NoopSleeper +} + +// Value implements Context.Value. +func (logContext) Value(key interface{}) interface{} { + return nil +} + +// bgContext is the context returned by context.Background. +var bgContext = &logContext{Logger: log.Log()} + +// Background returns an empty context using the default logger. +// +// Users should be wary of using a Background context. Please tag any use with +// FIXME(b/38173783) and a note to remove this use. +// +// Generally, one should use the Task as their context when available, or avoid +// having to use a context in places where a Task is unavailable. +// +// Using a Background context for tests is fine, as long as no values are +// needed from the context in the tested code paths. +func Background() Context { + return bgContext +} diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD new file mode 100644 index 000000000..426ef30c9 --- /dev/null +++ b/pkg/safecopy/BUILD @@ -0,0 +1,29 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "safecopy", + srcs = [ + "atomic_amd64.s", + "atomic_arm64.s", + "memclr_amd64.s", + "memclr_arm64.s", + "memcpy_amd64.s", + "memcpy_arm64.s", + "safecopy.go", + "safecopy_unsafe.go", + "sighandler_amd64.s", + "sighandler_arm64.s", + ], + visibility = ["//:sandbox"], + deps = ["//pkg/syserror"], +) + +go_test( + name = "safecopy_test", + srcs = [ + "safecopy_test.go", + ], + library = ":safecopy", +) diff --git a/pkg/safecopy/LICENSE b/pkg/safecopy/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/safecopy/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s new file mode 100644 index 000000000..a0cd78f33 --- /dev/null +++ b/pkg/safecopy/atomic_amd64.s @@ -0,0 +1,136 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// handleSwapUint32Fault returns the value stored in DI. Control is transferred +// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// swapUint32 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) +TEXT ·swapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint32Fault will store a different value in this address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL new+8(FP), AX + XCHGL AX, 0(DI) + MOVL AX, old+16(FP) + RET + +// handleSwapUint64Fault returns the value stored in DI. Control is transferred +// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as swapUint64 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 + MOVL DI, sig+24(FP) + RET + +// swapUint64 atomically stores new into *addr and returns (the previous *addr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: addr must be aligned to a 8-byte boundary. +// +//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) +TEXT ·swapUint64(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint64Fault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ addr+0(FP), DI + MOVQ new+8(FP), AX + XCHGQ AX, 0(DI) + MOVQ AX, old+16(FP) + RET + +// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is +// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the +// signal number stored in DI. +// +// It must have the same frame configuration as compareAndSwapUint32 so that it +// can undo any potential call frame set up by the assembler. +TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVL DI, sig+20(FP) + RET + +// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns +// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// received during the operation, the value of prev is unspecified, and sig is +// the number of the signal that was received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) +TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, this is + // the value the caller will see; if a signal is received, + // handleCompareAndSwapUint32Fault will store a different value in this + // address. + MOVL $0, sig+20(FP) + + MOVQ addr+0(FP), DI + MOVL old+8(FP), AX + MOVL new+12(FP), DX + LOCK + CMPXCHGL DX, 0(DI) + MOVL AX, prev+16(FP) + RET + +// handleLoadUint32Fault returns the value stored in DI. Control is transferred +// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as loadUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 + MOVL DI, sig+12(FP) + RET + +// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// signal is received, the value returned is unspecified, and sig is the number +// of the signal that was received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) +TEXT ·loadUint32(SB), NOSPLIT, $0-16 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleLoadUint32Fault will store a different value in this address. + MOVL $0, sig+12(FP) + + MOVQ addr+0(FP), AX + MOVL (AX), BX + MOVL BX, val+8(FP) + RET diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s new file mode 100644 index 000000000..d58ed71f7 --- /dev/null +++ b/pkg/safecopy/atomic_arm64.s @@ -0,0 +1,126 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// handleSwapUint32Fault returns the value stored in R1. Control is transferred +// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in R1. +// +// It must have the same frame configuration as swapUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVW R1, sig+20(FP) + RET + +// See the corresponding doc in safecopy_unsafe.go +// +// The code is derived from Go source runtime/internal/atomic.Xchg. +// +//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) +TEXT ·swapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint32Fault will store a different value in this address. + MOVW $0, sig+20(FP) +again: + MOVD addr+0(FP), R0 + MOVW new+8(FP), R1 + LDAXRW (R0), R2 + STLXRW R1, (R0), R3 + CBNZ R3, again + MOVW R2, old+16(FP) + RET + +// handleSwapUint64Fault returns the value stored in R1. Control is transferred +// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal +// number stored in R1. +// +// It must have the same frame configuration as swapUint64 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 + MOVW R1, sig+24(FP) + RET + +// See the corresponding doc in safecopy_unsafe.go +// +// The code is derived from Go source runtime/internal/atomic.Xchg64. +// +//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) +TEXT ·swapUint64(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleSwapUint64Fault will store a different value in this address. + MOVW $0, sig+24(FP) +again: + MOVD addr+0(FP), R0 + MOVD new+8(FP), R1 + LDAXR (R0), R2 + STLXR R1, (R0), R3 + CBNZ R3, again + MOVD R2, old+16(FP) + RET + +// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is +// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS, +// with the signal number stored in R1. +// +// It must have the same frame configuration as compareAndSwapUint32 so that it +// can undo any potential call frame set up by the assembler. +TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 + MOVW R1, sig+20(FP) + RET + +// See the corresponding doc in safecopy_unsafe.go +// +// The code is derived from Go source runtime/internal/atomic.Cas. +// +//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) +TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 + // Store 0 as the returned signal number. If we run to completion, this is + // the value the caller will see; if a signal is received, + // handleCompareAndSwapUint32Fault will store a different value in this + // address. + MOVW $0, sig+20(FP) + + MOVD addr+0(FP), R0 + MOVW old+8(FP), R1 + MOVW new+12(FP), R2 +again: + LDAXRW (R0), R3 + CMPW R1, R3 + BNE done + STLXRW R2, (R0), R4 + CBNZ R4, again +done: + MOVW R3, prev+16(FP) + RET + +// handleLoadUint32Fault returns the value stored in DI. Control is transferred +// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal +// number stored in DI. +// +// It must have the same frame configuration as loadUint32 so that it can undo +// any potential call frame set up by the assembler. +TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 + MOVW R1, sig+12(FP) + RET + +// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// signal is received, the value returned is unspecified, and sig is the number +// of the signal that was received. +// +// Preconditions: addr must be aligned to a 4-byte boundary. +// +//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) +TEXT ·loadUint32(SB), NOSPLIT, $0-16 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleLoadUint32Fault will store a different value in this address. + MOVW $0, sig+12(FP) + + MOVD addr+0(FP), R0 + LDARW (R0), R1 + MOVW R1, val+8(FP) + RET diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s new file mode 100644 index 000000000..64cf32f05 --- /dev/null +++ b/pkg/safecopy/memclr_amd64.s @@ -0,0 +1,147 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// handleMemclrFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memclr so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemclrFault(SB), NOSPLIT, $0-28 + MOVQ AX, addr+16(FP) + MOVL DI, sig+24(FP) + RET + +// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS +// signal is received during the write, it returns the address that caused the +// fault and the number of the signal that was received. Otherwise, it returns +// an unspecified address and a signal number of 0. +// +// Data is written in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully written. +// +// The code is derived from runtime.memclrNoHeapPointers. +// +// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memclr(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemclrFault will store a different value in this address. + MOVL $0, sig+24(FP) + + MOVQ ptr+0(FP), DI + MOVQ n+8(FP), BX + XORQ AX, AX + + // MOVOU seems always faster than REP STOSQ. +tail: + TESTQ BX, BX + JEQ _0 + CMPQ BX, $2 + JBE _1or2 + CMPQ BX, $4 + JBE _3or4 + CMPQ BX, $8 + JB _5through7 + JE _8 + CMPQ BX, $16 + JBE _9through16 + PXOR X0, X0 + CMPQ BX, $32 + JBE _17through32 + CMPQ BX, $64 + JBE _33through64 + CMPQ BX, $128 + JBE _65through128 + CMPQ BX, $256 + JBE _129through256 + // TODO: use branch table and BSR to make this just a single dispatch + // TODO: for really big clears, use MOVNTDQ, even without AVX2. + +loop: + MOVOU X0, 0(DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, 128(DI) + MOVOU X0, 144(DI) + MOVOU X0, 160(DI) + MOVOU X0, 176(DI) + MOVOU X0, 192(DI) + MOVOU X0, 208(DI) + MOVOU X0, 224(DI) + MOVOU X0, 240(DI) + SUBQ $256, BX + ADDQ $256, DI + CMPQ BX, $256 + JAE loop + JMP tail + +_1or2: + MOVB AX, (DI) + MOVB AX, -1(DI)(BX*1) + RET +_0: + RET +_3or4: + MOVW AX, (DI) + MOVW AX, -2(DI)(BX*1) + RET +_5through7: + MOVL AX, (DI) + MOVL AX, -4(DI)(BX*1) + RET +_8: + // We need a separate case for 8 to make sure we clear pointers atomically. + MOVQ AX, (DI) + RET +_9through16: + MOVQ AX, (DI) + MOVQ AX, -8(DI)(BX*1) + RET +_17through32: + MOVOU X0, (DI) + MOVOU X0, -16(DI)(BX*1) + RET +_33through64: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_65through128: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET +_129through256: + MOVOU X0, (DI) + MOVOU X0, 16(DI) + MOVOU X0, 32(DI) + MOVOU X0, 48(DI) + MOVOU X0, 64(DI) + MOVOU X0, 80(DI) + MOVOU X0, 96(DI) + MOVOU X0, 112(DI) + MOVOU X0, -128(DI)(BX*1) + MOVOU X0, -112(DI)(BX*1) + MOVOU X0, -96(DI)(BX*1) + MOVOU X0, -80(DI)(BX*1) + MOVOU X0, -64(DI)(BX*1) + MOVOU X0, -48(DI)(BX*1) + MOVOU X0, -32(DI)(BX*1) + MOVOU X0, -16(DI)(BX*1) + RET diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s new file mode 100644 index 000000000..7361b9067 --- /dev/null +++ b/pkg/safecopy/memclr_arm64.s @@ -0,0 +1,74 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// handleMemclrFault returns (the value stored in R0, the value stored in R1). +// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, +// with the faulting address stored in R0 and the signal number stored in R1. +// +// It must have the same frame configuration as memclr so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemclrFault(SB), NOSPLIT, $0-28 + MOVD R0, addr+16(FP) + MOVW R1, sig+24(FP) + RET + +// See the corresponding doc in safecopy_unsafe.go +// +// The code is derived from runtime.memclrNoHeapPointers. +// +// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memclr(SB), NOSPLIT, $0-28 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemclrFault will store a different value in this address. + MOVW $0, sig+24(FP) + MOVD ptr+0(FP), R0 + MOVD n+8(FP), R1 + + // If size is less than 16 bytes, use tail_zero to zero what remains + CMP $16, R1 + BLT tail_zero + // Get buffer offset into 16 byte aligned address for better performance + ANDS $15, R0, ZR + BNE unaligned_to_16 +aligned_to_16: + LSR $4, R1, R2 +zero_by_16: + STP.P (ZR, ZR), 16(R0) // Store pair with post index. + SUBS $1, R2, R2 + BNE zero_by_16 + ANDS $15, R1, R1 + BEQ end + + // Zero buffer with size=R1 < 16 +tail_zero: + TBZ $3, R1, tail_zero_4 + MOVD.P ZR, 8(R0) +tail_zero_4: + TBZ $2, R1, tail_zero_2 + MOVW.P ZR, 4(R0) +tail_zero_2: + TBZ $1, R1, tail_zero_1 + MOVH.P ZR, 2(R0) +tail_zero_1: + TBZ $0, R1, end + MOVB ZR, (R0) +end: + RET + +unaligned_to_16: + MOVD R0, R2 +head_loop: + MOVBU.P ZR, 1(R0) + ANDS $15, R0, ZR + BNE head_loop + // Adjust length for what remains + SUB R2, R0, R3 + SUB R3, R1 + // If size is less than 16 bytes, use tail_zero to zero what remains + CMP $16, R1 + BLT tail_zero + B aligned_to_16 diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s new file mode 100644 index 000000000..129691d68 --- /dev/null +++ b/pkg/safecopy/memcpy_amd64.s @@ -0,0 +1,250 @@ +// Copyright © 1994-1999 Lucent Technologies Inc. All rights reserved. +// Revisions Copyright © 2000-2007 Vita Nuova Holdings Limited (www.vitanuova.com). All rights reserved. +// Portions Copyright 2009 The Go Authors. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "textflag.h" + +// handleMemcpyFault returns (the value stored in AX, the value stored in DI). +// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, +// with the faulting address stored in AX and the signal number stored in DI. +// +// It must have the same frame configuration as memcpy so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 + MOVQ AX, addr+24(FP) + MOVL DI, sig+32(FP) + RET + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +// The code is derived from the forward copying part of runtime.memmove. +// +// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memcpy(SB), NOSPLIT, $0-36 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemcpyFault will store a different value in this address. + MOVL $0, sig+32(FP) + + MOVQ to+0(FP), DI + MOVQ from+8(FP), SI + MOVQ n+16(FP), BX + + // REP instructions have a high startup cost, so we handle small sizes + // with some straightline code. The REP MOVSQ instruction is really fast + // for large sizes. The cutover is approximately 2K. +tail: + // move_129through256 or smaller work whether or not the source and the + // destination memory regions overlap because they load all data into + // registers before writing it back. move_256through2048 on the other + // hand can be used only when the memory regions don't overlap or the copy + // direction is forward. + TESTQ BX, BX + JEQ move_0 + CMPQ BX, $2 + JBE move_1or2 + CMPQ BX, $4 + JBE move_3or4 + CMPQ BX, $8 + JB move_5through7 + JE move_8 + CMPQ BX, $16 + JBE move_9through16 + CMPQ BX, $32 + JBE move_17through32 + CMPQ BX, $64 + JBE move_33through64 + CMPQ BX, $128 + JBE move_65through128 + CMPQ BX, $256 + JBE move_129through256 + // TODO: use branch table and BSR to make this just a single dispatch + +/* + * forward copy loop + */ + CMPQ BX, $2048 + JLS move_256through2048 + + // Check alignment + MOVL SI, AX + ORL DI, AX + TESTL $7, AX + JEQ fwdBy8 + + // Do 1 byte at a time + MOVQ BX, CX + REP; MOVSB + RET + +fwdBy8: + // Do 8 bytes at a time + MOVQ BX, CX + SHRQ $3, CX + ANDQ $7, BX + REP; MOVSQ + JMP tail + +move_1or2: + MOVB (SI), AX + MOVB AX, (DI) + MOVB -1(SI)(BX*1), CX + MOVB CX, -1(DI)(BX*1) + RET +move_0: + RET +move_3or4: + MOVW (SI), AX + MOVW AX, (DI) + MOVW -2(SI)(BX*1), CX + MOVW CX, -2(DI)(BX*1) + RET +move_5through7: + MOVL (SI), AX + MOVL AX, (DI) + MOVL -4(SI)(BX*1), CX + MOVL CX, -4(DI)(BX*1) + RET +move_8: + // We need a separate case for 8 to make sure we write pointers atomically. + MOVQ (SI), AX + MOVQ AX, (DI) + RET +move_9through16: + MOVQ (SI), AX + MOVQ AX, (DI) + MOVQ -8(SI)(BX*1), CX + MOVQ CX, -8(DI)(BX*1) + RET +move_17through32: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU -16(SI)(BX*1), X1 + MOVOU X1, -16(DI)(BX*1) + RET +move_33through64: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU -32(SI)(BX*1), X2 + MOVOU X2, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X3 + MOVOU X3, -16(DI)(BX*1) + RET +move_65through128: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU -64(SI)(BX*1), X4 + MOVOU X4, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X5 + MOVOU X5, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X6 + MOVOU X6, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X7 + MOVOU X7, -16(DI)(BX*1) + RET +move_129through256: + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU -128(SI)(BX*1), X8 + MOVOU X8, -128(DI)(BX*1) + MOVOU -112(SI)(BX*1), X9 + MOVOU X9, -112(DI)(BX*1) + MOVOU -96(SI)(BX*1), X10 + MOVOU X10, -96(DI)(BX*1) + MOVOU -80(SI)(BX*1), X11 + MOVOU X11, -80(DI)(BX*1) + MOVOU -64(SI)(BX*1), X12 + MOVOU X12, -64(DI)(BX*1) + MOVOU -48(SI)(BX*1), X13 + MOVOU X13, -48(DI)(BX*1) + MOVOU -32(SI)(BX*1), X14 + MOVOU X14, -32(DI)(BX*1) + MOVOU -16(SI)(BX*1), X15 + MOVOU X15, -16(DI)(BX*1) + RET +move_256through2048: + SUBQ $256, BX + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU 128(SI), X8 + MOVOU X8, 128(DI) + MOVOU 144(SI), X9 + MOVOU X9, 144(DI) + MOVOU 160(SI), X10 + MOVOU X10, 160(DI) + MOVOU 176(SI), X11 + MOVOU X11, 176(DI) + MOVOU 192(SI), X12 + MOVOU X12, 192(DI) + MOVOU 208(SI), X13 + MOVOU X13, 208(DI) + MOVOU 224(SI), X14 + MOVOU X14, 224(DI) + MOVOU 240(SI), X15 + MOVOU X15, 240(DI) + CMPQ BX, $256 + LEAQ 256(SI), SI + LEAQ 256(DI), DI + JGE move_256through2048 + JMP tail diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s new file mode 100644 index 000000000..e7e541565 --- /dev/null +++ b/pkg/safecopy/memcpy_arm64.s @@ -0,0 +1,78 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "textflag.h" + +// handleMemcpyFault returns (the value stored in R0, the value stored in R1). +// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, +// with the faulting address stored in R0 and the signal number stored in R1. +// +// It must have the same frame configuration as memcpy so that it can undo any +// potential call frame set up by the assembler. +TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 + MOVD R0, addr+24(FP) + MOVW R1, sig+32(FP) + RET + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +// The code is derived from the Go source runtime.memmove. +// +// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) +TEXT ·memcpy(SB), NOSPLIT, $-8-36 + // Store 0 as the returned signal number. If we run to completion, + // this is the value the caller will see; if a signal is received, + // handleMemcpyFault will store a different value in this address. + MOVW $0, sig+32(FP) + + MOVD to+0(FP), R3 + MOVD from+8(FP), R4 + MOVD n+16(FP), R5 + CMP $0, R5 + BNE check + RET + +check: + AND $~7, R5, R7 // R7 is N&~7. + SUB R7, R5, R6 // R6 is N&7. + + // Copying forward proceeds by copying R7/8 words then copying R6 bytes. + // R3 and R4 are advanced as we copy. + + // (There may be implementations of armv8 where copying by bytes until + // at least one of source or dest is word aligned is a worthwhile + // optimization, but the on the one tested so far (xgene) it did not + // make a significance difference.) + + CMP $0, R7 // Do we need to do any word-by-word copying? + BEQ noforwardlarge + ADD R3, R7, R9 // R9 points just past where we copy by word. + +forwardlargeloop: + MOVD.P 8(R4), R8 // R8 is just a scratch register. + MOVD.P R8, 8(R3) + CMP R3, R9 + BNE forwardlargeloop + +noforwardlarge: + CMP $0, R6 // Do we need to do any byte-by-byte copying? + BNE forwardtail + RET + +forwardtail: + ADD R3, R6, R9 // R9 points just past the destination memory. + +forwardtailloop: + MOVBU.P 1(R4), R8 + MOVBU.P R8, 1(R3) + CMP R3, R9 + BNE forwardtailloop + RET diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go new file mode 100644 index 000000000..2fb7e5809 --- /dev/null +++ b/pkg/safecopy/safecopy.go @@ -0,0 +1,144 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safecopy provides an efficient implementation of functions to access +// memory that may result in SIGSEGV or SIGBUS being sent to the accessor. +package safecopy + +import ( + "fmt" + "reflect" + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/syserror" +) + +// SegvError is returned when a safecopy function receives SIGSEGV. +type SegvError struct { + // Addr is the address at which the SIGSEGV occurred. + Addr uintptr +} + +// Error implements error.Error. +func (e SegvError) Error() string { + return fmt.Sprintf("SIGSEGV at %#x", e.Addr) +} + +// BusError is returned when a safecopy function receives SIGBUS. +type BusError struct { + // Addr is the address at which the SIGBUS occurred. + Addr uintptr +} + +// Error implements error.Error. +func (e BusError) Error() string { + return fmt.Sprintf("SIGBUS at %#x", e.Addr) +} + +// AlignmentError is returned when a safecopy function is passed an address +// that does not meet alignment requirements. +type AlignmentError struct { + // Addr is the invalid address. + Addr uintptr + + // Alignment is the required alignment. + Alignment uintptr +} + +// Error implements error.Error. +func (e AlignmentError) Error() string { + return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment) +} + +var ( + // The begin and end addresses below are for the functions that are + // checked by the signal handler. + memcpyBegin uintptr + memcpyEnd uintptr + memclrBegin uintptr + memclrEnd uintptr + swapUint32Begin uintptr + swapUint32End uintptr + swapUint64Begin uintptr + swapUint64End uintptr + compareAndSwapUint32Begin uintptr + compareAndSwapUint32End uintptr + loadUint32Begin uintptr + loadUint32End uintptr + + // savedSigSegVHandler is a pointer to the SIGSEGV handler that was + // configured before we replaced it with our own. We still call into it + // when we get a SIGSEGV that is not interesting to us. + savedSigSegVHandler uintptr + + // same a above, but for SIGBUS signals. + savedSigBusHandler uintptr +) + +// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS +// signals. +func signalHandler() + +// FindEndAddress returns the end address (one byte beyond the last) of the +// function that contains the specified address (begin). +func FindEndAddress(begin uintptr) uintptr { + f := runtime.FuncForPC(begin) + if f != nil { + for p := begin; ; p++ { + g := runtime.FuncForPC(p) + if f != g { + return p + } + } + } + return begin +} + +// initializeAddresses initializes the addresses used by the signal handler. +func initializeAddresses() { + // The following functions are written in assembly language, so they won't + // be inlined by the existing compiler/linker. Tests will fail if this + // assumption is violated. + memcpyBegin = reflect.ValueOf(memcpy).Pointer() + memcpyEnd = FindEndAddress(memcpyBegin) + memclrBegin = reflect.ValueOf(memclr).Pointer() + memclrEnd = FindEndAddress(memclrBegin) + swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() + swapUint32End = FindEndAddress(swapUint32Begin) + swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() + swapUint64End = FindEndAddress(swapUint64Begin) + compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() + compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) + loadUint32Begin = reflect.ValueOf(loadUint32).Pointer() + loadUint32End = FindEndAddress(loadUint32Begin) +} + +func init() { + initializeAddresses() + if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) + } + if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) + } + syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) { + switch e.(type) { + case SegvError, BusError, AlignmentError: + return syscall.EFAULT, true + default: + return 0, false + } + }) +} diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go new file mode 100644 index 000000000..5818f7f9b --- /dev/null +++ b/pkg/safecopy/safecopy_test.go @@ -0,0 +1,617 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "runtime/debug" + "syscall" + "testing" + "unsafe" +) + +// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular +// dependency. +const pageSize = 4096 + +func initRandom(b []byte) { + for i := range b { + b[i] = byte(rand.Intn(256)) + } +} + +func randBuf(size int) []byte { + b := make([]byte, size) + initRandom(b) + return b +} + +func TestCopyInSuccess(t *testing.T) { + // Test that CopyIn does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyIn(b, unsafe.Pointer(&a[0])) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopyOutSuccess(t *testing.T) { + // Test that CopyOut does not return an error when all pages are + // accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := CopyOut(unsafe.Pointer(&b[0]), a) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestCopySuccess(t *testing.T) { + // Test that Copy does not return an error when all pages are accessible. + const bufLen = 8192 + a := randBuf(bufLen) + b := make([]byte, bufLen) + + n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestZeroOutSuccess(t *testing.T) { + // Test that ZeroOut does not return an error when all pages are + // accessible. + const bufLen = 8192 + a := make([]byte, bufLen) + b := randBuf(bufLen) + + n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) + if n != bufLen { + t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !bytes.Equal(a, b) { + t.Errorf("Buffers are not equal when they should be: %v %v", a, b) + } +} + +func TestSwapUint32Success(t *testing.T) { + // Test that SwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := SwapUint32(unsafe.Pointer(&val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestSwapUint32AlignmentError(t *testing.T) { + // Test that SwapUint32 returns an AlignmentError when passed an unaligned + // address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestSwapUint64Success(t *testing.T) { + // Test that SwapUint64 does not return an error when the page is + // accessible. + before := uint64(rand.Int63()) + after := uint64(rand.Int63()) + // "The first word in ... an allocated struct or slice can be relied upon + // to be 64-bit aligned." - sync/atomic docs + data := new(struct{ val uint64 }) + data.val = before + + old, err := SwapUint64(unsafe.Pointer(&data.val), after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if data.val != after { + t.Errorf("Unexpected new value: got %v, want %v", data.val, after) + } +} + +func TestSwapUint64AlignmentError(t *testing.T) { + // Test that SwapUint64 returns an AlignmentError when passed an unaligned + // address. + data := new(struct{ val1, val2 uint64 }) + addr := uintptr(unsafe.Pointer(&data.val1)) + 1 + want := AlignmentError{Addr: addr, Alignment: 8} + if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +func TestCompareAndSwapUint32Success(t *testing.T) { + // Test that CompareAndSwapUint32 does not return an error when the page is + // accessible. + before := uint32(rand.Int31()) + after := uint32(rand.Int31()) + val := before + + old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if old != before { + t.Errorf("Unexpected old value: got %v, want %v", old, before) + } + if val != after { + t.Errorf("Unexpected new value: got %v, want %v", val, after) + } +} + +func TestCompareAndSwapUint32AlignmentError(t *testing.T) { + // Test that CompareAndSwapUint32 returns an AlignmentError when passed an + // unaligned address. + data := new(struct{ val uint64 }) + addr := uintptr(unsafe.Pointer(&data.val)) + 1 + want := AlignmentError{Addr: addr, Alignment: 4} + if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } +} + +// withSegvErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGSEGV when accessed. +func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { + mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { + t.Fatalf("Mprotect failed: %v", err) + } + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +// withBusErrorTestMapping calls fn with a two-page mapping. The first page +// contains random data, and the second page generates SIGBUS when accessed. +func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + defer f.Close() + if err := f.Truncate(pageSize); err != nil { + t.Fatalf("Truncate failed: %v", err) + } + mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + } + defer syscall.Munmap(mapping) + initRandom(mapping[:pageSize]) + + fn(mapping) +} + +func TestCopyInSegvError(t *testing.T) { + // Test that CopyIn returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyInBusError(t *testing.T) { + // Test that CopyIn returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := CopyIn(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutSegvError(t *testing.T) { + // Test that CopyOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := CopyOut(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyOutBusError(t *testing.T) { + // Test that CopyOut returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := CopyOut(dst, src) + if n != bytesBeforeFault { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopySourceSegvError(t *testing.T) { + // Test that Copy returns a SegvError when copying from a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopySourceBusError(t *testing.T) { + // Test that Copy returns a BusError when copying from a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + dst := randBuf(pageSize) + n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyDestinationSegvError(t *testing.T) { + // Test that Copy returns a SegvError when copying to a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestCopyDestinationBusError(t *testing.T) { + // Test that Copy returns a BusError when copying to a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + src := randBuf(pageSize) + n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { + t.Errorf("Buffers are not equal when they should be: %v %v", got, want) + } + }) + }) + } +} + +func TestZeroOutSegvError(t *testing.T) { + // Test that ZeroOut returns a SegvError when reaching a page that signals + // SIGSEGV. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestZeroOutBusError(t *testing.T) { + // Test that ZeroOut returns a BusError when reaching a page that signals + // SIGBUS. + for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { + t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) + n, err := ZeroOut(dst, pageSize) + if n != uintptr(bytesBeforeFault) { + t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) + } + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { + t.Errorf("Non-zero bytes in written part of mapping: %v", got) + } + }) + }) + } +} + +func TestSwapUint32SegvError(t *testing.T) { + // Test that SwapUint32 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint32BusError(t *testing.T) { + // Test that SwapUint32 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint32(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64SegvError(t *testing.T) { + // Test that SwapUint64 returns a SegvError when reaching a page that + // signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestSwapUint64BusError(t *testing.T) { + // Test that SwapUint64 returns a BusError when reaching a page that + // signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := SwapUint64(unsafe.Pointer(secondPage), 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32SegvError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a SegvError when reaching a page + // that signals SIGSEGV. + withSegvErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (SegvError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func TestCompareAndSwapUint32BusError(t *testing.T) { + // Test that CompareAndSwapUint32 returns a BusError when reaching a page + // that signals SIGBUS. + withBusErrorTestMapping(t, func(mapping []byte) { + secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize + _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) + if want := (BusError{secondPage}); err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + } + }) +} + +func testCopy(dst, src []byte) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + debug.SetPanicOnFault(true) + copy(dst, src) + return +} + +func TestSegVOnMemmove(t *testing.T) { + // Test that SIGSEGVs received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} + +func TestSigbusOnMemmove(t *testing.T) { + // Test that SIGBUS received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + os.Remove(f.Name()) + defer f.Close() + + a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer syscall.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go new file mode 100644 index 000000000..eef028e68 --- /dev/null +++ b/pkg/safecopy/safecopy_unsafe.go @@ -0,0 +1,335 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safecopy + +import ( + "fmt" + "syscall" + "unsafe" +) + +// maxRegisterSize is the maximum register size used in memcpy and memclr. It +// is used to decide by how much to rewind the copy (for memcpy) or zeroing +// (for memclr) before proceeding. +const maxRegisterSize = 16 + +// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received +// during the copy, it returns the address that caused the fault and the number +// of the signal that was received. Otherwise, it returns an unspecified address +// and a signal number of 0. +// +// Data is copied in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully copied. +// +//go:noescape +func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) + +// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS +// signal is received during the write, it returns the address that caused the +// fault and the number of the signal that was received. Otherwise, it returns +// an unspecified address and a signal number of 0. +// +// Data is written in order, such that if a fault happens at address p, it is +// safe to assume that all data before p-maxRegisterSize has already been +// successfully written. +// +//go:noescape +func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) + +// swapUint32 atomically stores new into *ptr and returns (the previous *ptr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +// +//go:noescape +func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) + +// swapUint64 atomically stores new into *ptr and returns (the previous *ptr +// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the +// value of old is unspecified, and sig is the number of the signal that was +// received. +// +// Preconditions: ptr must be aligned to a 8-byte boundary. +// +//go:noescape +func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) + +// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns +// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is +// received during the operation, the value of prev is unspecified, and sig is +// the number of the signal that was received. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +// +//go:noescape +func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) + +// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It +// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +// +//go:noescape +func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) + +// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes +// copied and an error if SIGSEGV or SIGBUS is received while reading from src. +func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { + toCopy := uintptr(len(dst)) + if len(dst) == 0 { + return 0, nil + } + + fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy) + if sig == 0 { + return len(dst), nil + } + + faultN, srcN := uintptr(fault), uintptr(src) + if faultN < srcN || faultN >= srcN+toCopy { + panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy)) + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + var done int + if faultN-srcN > maxRegisterSize { + done = int(faultN - srcN - maxRegisterSize) + } + n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done))) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// CopyOut copies len(src) bytes from src to dst. If returns the number of +// bytes done and an error if SIGSEGV or SIGBUS is received while writing to +// dst. +func CopyOut(dst unsafe.Pointer, src []byte) (int, error) { + toCopy := uintptr(len(src)) + if toCopy == 0 { + return 0, nil + } + + fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy) + if sig == 0 { + return len(src), nil + } + + faultN, dstN := uintptr(fault), uintptr(dst) + if faultN < dstN || faultN >= dstN+toCopy { + panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy)) + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + var done int + if faultN-dstN > maxRegisterSize { + done = int(faultN - dstN - maxRegisterSize) + } + n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)]) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// Copy copies toCopy bytes from src to dst. It returns the number of bytes +// copied and an error if SIGSEGV or SIGBUS is received while reading from src +// or writing to dst. +// +// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap, +// the resulting contents of dst are unspecified. +func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) { + if toCopy == 0 { + return 0, nil + } + + fault, sig := memcpy(dst, src, toCopy) + if sig == 0 { + return toCopy, nil + } + + // Did the fault occur while reading from src or writing to dst? + faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst) + faultAfterSrc := ^uintptr(0) + if faultN >= srcN { + faultAfterSrc = faultN - srcN + } + faultAfterDst := ^uintptr(0) + if faultN >= dstN { + faultAfterDst = faultN - dstN + } + if faultAfterSrc >= toCopy && faultAfterDst >= toCopy { + panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy)) + } + faultedAfter := faultAfterSrc + if faultedAfter > faultAfterDst { + faultedAfter = faultAfterDst + } + + // memcpy might have ended the copy up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to copy up to the fault. + var done uintptr + if faultedAfter > maxRegisterSize { + done = faultedAfter - maxRegisterSize + } + n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes +// written and an error if SIGSEGV or SIGBUS is received while writing to dst. +func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) { + if toZero == 0 { + return 0, nil + } + + fault, sig := memclr(dst, toZero) + if sig == 0 { + return toZero, nil + } + + faultN, dstN := uintptr(fault), uintptr(dst) + if faultN < dstN || faultN >= dstN+toZero { + panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero)) + } + + // memclr might have ended the write up to maxRegisterSize bytes before + // fault, if an instruction caused a memory access that straddled two + // pages, and the second one faulted. Try to write up to the fault. + var done uintptr + if faultN-dstN > maxRegisterSize { + done = faultN - dstN - maxRegisterSize + } + n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done) + done += n + if err != nil { + return done, err + } + return done, errorFromFaultSignal(fault, sig) +} + +// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns +// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is +// not aligned to a 4-byte boundary. +func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) { + if addr := uintptr(ptr); addr&3 != 0 { + return 0, AlignmentError{addr, 4} + } + old, sig := swapUint32(ptr, new) + return old, errorFromFaultSignal(ptr, sig) +} + +// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns +// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is +// not aligned to an 8-byte boundary. +func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) { + if addr := uintptr(ptr); addr&7 != 0 { + return 0, AlignmentError{addr, 8} + } + old, sig := swapUint64(ptr, new) + return old, errorFromFaultSignal(ptr, sig) +} + +// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32, +// except that it returns an error if SIGSEGV or SIGBUS is received while +// accessing ptr, or if ptr is not aligned to a 4-byte boundary. +func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) { + if addr := uintptr(ptr); addr&3 != 0 { + return 0, AlignmentError{addr, 4} + } + prev, sig := compareAndSwapUint32(ptr, old, new) + return prev, errorFromFaultSignal(ptr, sig) +} + +// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It +// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr. +// +// Preconditions: ptr must be aligned to a 4-byte boundary. +func LoadUint32(ptr unsafe.Pointer) (uint32, error) { + if addr := uintptr(ptr); addr&3 != 0 { + return 0, AlignmentError{addr, 4} + } + val, sig := loadUint32(ptr) + return val, errorFromFaultSignal(ptr, sig) +} + +func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error { + switch sig { + case 0: + return nil + case int32(syscall.SIGSEGV): + return SegvError{uintptr(addr)} + case int32(syscall.SIGBUS): + return BusError{uintptr(addr)} + default: + panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) + } +} + +// ReplaceSignalHandler replaces the existing signal handler for the provided +// signal with the one that handles faults in safecopy-protected functions. +// +// It stores the value of the previously set handler in previous. +// +// This function will be called on initialization in order to install safecopy +// handlers for appropriate signals. These handlers will call the previous +// handler however, and if this is function is being used externally then the +// same courtesy is expected. +func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error { + var sa struct { + handler uintptr + flags uint64 + restorer uintptr + mask uint64 + } + const maskLen = 8 + + // Get the existing signal handler information, and save the current + // handler. Once we replace it, we will use this pointer to fall back to + // it when we receive other signals. + if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { + return e + } + + // Fail if there isn't a previous handler. + if sa.handler == 0 { + return fmt.Errorf("previous handler for signal %x isn't set", sig) + } + + *previous = sa.handler + + // Install our own handler. + sa.handler = handler + if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { + return e + } + + return nil +} diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s new file mode 100644 index 000000000..475ae48e9 --- /dev/null +++ b/pkg/safecopy/sighandler_amd64.s @@ -0,0 +1,133 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// The signals handled by sigHandler. +#define SIGBUS 7 +#define SIGSEGV 11 + +// Offsets to the registers in context->uc_mcontext.gregs[]. +#define REG_RDI 0x68 +#define REG_RAX 0x90 +#define REG_IP 0xa8 + +// Offset to the si_addr field of siginfo. +#define SI_CODE 0x08 +#define SI_ADDR 0x10 + +// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must +// not be set up as a handler to any other signals. +// +// If the instruction causing the signal is within a safecopy-protected +// function, the signal is handled such that execution resumes in the +// appropriate fault handling stub with AX containing the faulting address and +// DI containing the signal number. Otherwise control is transferred to the +// previously configured signal handler (savedSigSegvHandler or +// savedSigBusHandler). +// +// This function cannot be written in go because it runs whenever a signal is +// received by the thread (preempting whatever was running), which includes when +// garbage collector has stopped or isn't expecting any interactions (like +// barriers). +// +// The arguments are the following: +// DI - The signal number. +// SI - Pointer to siginfo_t structure. +// DX - Pointer to ucontext structure. +TEXT ·signalHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel. + MOVQ $0x0, CX + CMPL CX, SI_CODE(SI) + JGE original_handler + + // Check if RIP is within the area we care about. + MOVQ REG_IP(DX), CX + CMPQ CX, ·memcpyBegin(SB) + JB not_memcpy + CMPQ CX, ·memcpyEnd(SB) + JAE not_memcpy + + // Modify the context such that execution will resume in the fault + // handler. + LEAQ handleMemcpyFault(SB), CX + JMP handle_fault + +not_memcpy: + CMPQ CX, ·memclrBegin(SB) + JB not_memclr + CMPQ CX, ·memclrEnd(SB) + JAE not_memclr + + LEAQ handleMemclrFault(SB), CX + JMP handle_fault + +not_memclr: + CMPQ CX, ·swapUint32Begin(SB) + JB not_swapuint32 + CMPQ CX, ·swapUint32End(SB) + JAE not_swapuint32 + + LEAQ handleSwapUint32Fault(SB), CX + JMP handle_fault + +not_swapuint32: + CMPQ CX, ·swapUint64Begin(SB) + JB not_swapuint64 + CMPQ CX, ·swapUint64End(SB) + JAE not_swapuint64 + + LEAQ handleSwapUint64Fault(SB), CX + JMP handle_fault + +not_swapuint64: + CMPQ CX, ·compareAndSwapUint32Begin(SB) + JB not_casuint32 + CMPQ CX, ·compareAndSwapUint32End(SB) + JAE not_casuint32 + + LEAQ handleCompareAndSwapUint32Fault(SB), CX + JMP handle_fault + +not_casuint32: + CMPQ CX, ·loadUint32Begin(SB) + JB not_loaduint32 + CMPQ CX, ·loadUint32End(SB) + JAE not_loaduint32 + + LEAQ handleLoadUint32Fault(SB), CX + JMP handle_fault + +not_loaduint32: +original_handler: + // Jump to the previous signal handler, which is likely the golang one. + XORQ CX, CX + MOVQ ·savedSigBusHandler(SB), AX + CMPL DI, $SIGSEGV + CMOVQEQ ·savedSigSegVHandler(SB), AX + JMP AX + +handle_fault: + // Entered with the address of the fault handler in RCX; store it in + // RIP. + MOVQ CX, REG_IP(DX) + + // Store the faulting address in RAX. + MOVQ SI_ADDR(SI), CX + MOVQ CX, REG_RAX(DX) + + // Store the signal number in EDI. + MOVL DI, REG_RDI(DX) + + RET diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s new file mode 100644 index 000000000..53e4ac2c1 --- /dev/null +++ b/pkg/safecopy/sighandler_arm64.s @@ -0,0 +1,143 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// The signals handled by sigHandler. +#define SIGBUS 7 +#define SIGSEGV 11 + +// Offsets to the registers in context->uc_mcontext.gregs[]. +#define REG_R0 0xB8 +#define REG_R1 0xC0 +#define REG_PC 0x1B8 + +// Offset to the si_addr field of siginfo. +#define SI_CODE 0x08 +#define SI_ADDR 0x10 + +// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must +// not be set up as a handler to any other signals. +// +// If the instruction causing the signal is within a safecopy-protected +// function, the signal is handled such that execution resumes in the +// appropriate fault handling stub with R0 containing the faulting address and +// R1 containing the signal number. Otherwise control is transferred to the +// previously configured signal handler (savedSigSegvHandler or +// savedSigBusHandler). +// +// This function cannot be written in go because it runs whenever a signal is +// received by the thread (preempting whatever was running), which includes when +// garbage collector has stopped or isn't expecting any interactions (like +// barriers). +// +// The arguments are the following: +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +TEXT ·signalHandler(SB),NOSPLIT,$0 + // Check if the signal is from the kernel, si_code > 0 means a kernel signal. + MOVD SI_CODE(R1), R7 + CMPW $0x0, R7 + BLE original_handler + + // Check if PC is within the area we care about. + MOVD REG_PC(R2), R7 + MOVD ·memcpyBegin(SB), R8 + CMP R8, R7 + BLO not_memcpy + MOVD ·memcpyEnd(SB), R8 + CMP R8, R7 + BHS not_memcpy + + // Modify the context such that execution will resume in the fault handler. + MOVD $handleMemcpyFault(SB), R7 + B handle_fault + +not_memcpy: + MOVD ·memclrBegin(SB), R8 + CMP R8, R7 + BLO not_memclr + MOVD ·memclrEnd(SB), R8 + CMP R8, R7 + BHS not_memclr + + MOVD $handleMemclrFault(SB), R7 + B handle_fault + +not_memclr: + MOVD ·swapUint32Begin(SB), R8 + CMP R8, R7 + BLO not_swapuint32 + MOVD ·swapUint32End(SB), R8 + CMP R8, R7 + BHS not_swapuint32 + + MOVD $handleSwapUint32Fault(SB), R7 + B handle_fault + +not_swapuint32: + MOVD ·swapUint64Begin(SB), R8 + CMP R8, R7 + BLO not_swapuint64 + MOVD ·swapUint64End(SB), R8 + CMP R8, R7 + BHS not_swapuint64 + + MOVD $handleSwapUint64Fault(SB), R7 + B handle_fault + +not_swapuint64: + MOVD ·compareAndSwapUint32Begin(SB), R8 + CMP R8, R7 + BLO not_casuint32 + MOVD ·compareAndSwapUint32End(SB), R8 + CMP R8, R7 + BHS not_casuint32 + + MOVD $handleCompareAndSwapUint32Fault(SB), R7 + B handle_fault + +not_casuint32: + MOVD ·loadUint32Begin(SB), R8 + CMP R8, R7 + BLO not_loaduint32 + MOVD ·loadUint32End(SB), R8 + CMP R8, R7 + BHS not_loaduint32 + + MOVD $handleLoadUint32Fault(SB), R7 + B handle_fault + +not_loaduint32: +original_handler: + // Jump to the previous signal handler, which is likely the golang one. + MOVD ·savedSigBusHandler(SB), R7 + MOVD ·savedSigSegVHandler(SB), R8 + CMPW $SIGSEGV, R0 + CSEL EQ, R8, R7, R7 + B (R7) + +handle_fault: + // Entered with the address of the fault handler in R7; store it in PC. + MOVD R7, REG_PC(R2) + + // Store the faulting address in R0. + MOVD SI_ADDR(R1), R7 + MOVD R7, REG_R0(R2) + + // Store the signal number in R1. + MOVW R0, REG_R1(R2) + + RET diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD new file mode 100644 index 000000000..ce30382ab --- /dev/null +++ b/pkg/safemem/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "safemem", + srcs = [ + "block_unsafe.go", + "io.go", + "safemem.go", + "seq_unsafe.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/safecopy", + ], +) + +go_test( + name = "safemem_test", + size = "small", + srcs = [ + "io_test.go", + "seq_test.go", + ], + library = ":safemem", +) diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go new file mode 100644 index 000000000..e7fd30743 --- /dev/null +++ b/pkg/safemem/block_unsafe.go @@ -0,0 +1,279 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "fmt" + "reflect" + "unsafe" + + "gvisor.dev/gvisor/pkg/safecopy" +) + +// A Block is a range of contiguous bytes, similar to []byte but with the +// following differences: +// +// - The memory represented by a Block may require the use of safecopy to +// access. +// +// - Block does not carry a capacity and cannot be expanded. +// +// Blocks are immutable and may be copied by value. The zero value of Block +// represents an empty range, analogous to a nil []byte. +type Block struct { + // [start, start+length) is the represented memory. + // + // start is an unsafe.Pointer to ensure that Block prevents the represented + // memory from being garbage-collected. + start unsafe.Pointer + length int + + // needSafecopy is true if accessing the represented memory requires the + // use of safecopy. + needSafecopy bool +} + +// BlockFromSafeSlice returns a Block equivalent to slice, which is safe to +// access without safecopy. +func BlockFromSafeSlice(slice []byte) Block { + return blockFromSlice(slice, false) +} + +// BlockFromUnsafeSlice returns a Block equivalent to bs, which is not safe to +// access without safecopy. +func BlockFromUnsafeSlice(slice []byte) Block { + return blockFromSlice(slice, true) +} + +func blockFromSlice(slice []byte, needSafecopy bool) Block { + if len(slice) == 0 { + return Block{} + } + return Block{ + start: unsafe.Pointer(&slice[0]), + length: len(slice), + needSafecopy: needSafecopy, + } +} + +// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is +// safe to access without safecopy. +// +// Preconditions: ptr+len does not overflow. +func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { + return blockFromPointer(ptr, len, false) +} + +// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which +// is not safe to access without safecopy. +// +// Preconditions: ptr+len does not overflow. +func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { + return blockFromPointer(ptr, len, true) +} + +func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { + if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { + panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) + } + return Block{ + start: ptr, + length: len, + needSafecopy: needSafecopy, + } +} + +// DropFirst returns a Block equivalent to b, but with the first n bytes +// omitted. It is analogous to the [n:] operation on a slice, except that if n +// > b.Len(), DropFirst returns an empty Block instead of panicking. +// +// Preconditions: n >= 0. +func (b Block) DropFirst(n int) Block { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return b.DropFirst64(uint64(n)) +} + +// DropFirst64 is equivalent to DropFirst but takes a uint64. +func (b Block) DropFirst64(n uint64) Block { + if n >= uint64(b.length) { + return Block{} + } + return Block{ + start: unsafe.Pointer(uintptr(b.start) + uintptr(n)), + length: b.length - int(n), + needSafecopy: b.needSafecopy, + } +} + +// TakeFirst returns a Block equivalent to the first n bytes of b. It is +// analogous to the [:n] operation on a slice, except that if n > b.Len(), +// TakeFirst returns a copy of b instead of panicking. +// +// Preconditions: n >= 0. +func (b Block) TakeFirst(n int) Block { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return b.TakeFirst64(uint64(n)) +} + +// TakeFirst64 is equivalent to TakeFirst but takes a uint64. +func (b Block) TakeFirst64(n uint64) Block { + if n == 0 { + return Block{} + } + if n >= uint64(b.length) { + return b + } + return Block{ + start: b.start, + length: int(n), + needSafecopy: b.needSafecopy, + } +} + +// ToSlice returns a []byte equivalent to b. +func (b Block) ToSlice() []byte { + var bs []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + hdr.Data = uintptr(b.start) + hdr.Len = b.length + hdr.Cap = b.length + return bs +} + +// Addr returns b's start address as a uintptr. It returns uintptr instead of +// unsafe.Pointer so that code using safemem cannot obtain unsafe.Pointers +// without importing the unsafe package explicitly. +// +// Note that a uintptr is not recognized as a pointer by the garbage collector, +// such that if there are no uses of b after a call to b.Addr() and the address +// is to Go-managed memory, the returned uintptr does not prevent garbage +// collection of the pointee. +func (b Block) Addr() uintptr { + return uintptr(b.start) +} + +// Len returns b's length in bytes. +func (b Block) Len() int { + return b.length +} + +// NeedSafecopy returns true if accessing b.ToSlice() requires the use of safecopy. +func (b Block) NeedSafecopy() bool { + return b.needSafecopy +} + +// String implements fmt.Stringer.String. +func (b Block) String() string { + if uintptr(b.start) == 0 && b.length == 0 { + return "" + } + var suffix string + if b.needSafecopy { + suffix = "*" + } + return fmt.Sprintf("[%#x-%#x)%s", uintptr(b.start), uintptr(b.start)+uintptr(b.length), suffix) +} + +// Copy copies src.Len() or dst.Len() bytes, whichever is less, from src +// to dst and returns the number of bytes copied. +// +// If src and dst overlap, the data stored in dst is unspecified. +func Copy(dst, src Block) (int, error) { + if !dst.needSafecopy && !src.needSafecopy { + return copy(dst.ToSlice(), src.ToSlice()), nil + } + + n := dst.length + if n > src.length { + n = src.length + } + if n == 0 { + return 0, nil + } + + switch { + case dst.needSafecopy && !src.needSafecopy: + return safecopy.CopyOut(dst.start, src.TakeFirst(n).ToSlice()) + case !dst.needSafecopy && src.needSafecopy: + return safecopy.CopyIn(dst.TakeFirst(n).ToSlice(), src.start) + case dst.needSafecopy && src.needSafecopy: + n64, err := safecopy.Copy(dst.start, src.start, uintptr(n)) + return int(n64), err + default: + panic("unreachable") + } +} + +// Zero sets all bytes in dst to 0 and returns the number of bytes zeroed. +func Zero(dst Block) (int, error) { + if !dst.needSafecopy { + bs := dst.ToSlice() + for i := range bs { + bs[i] = 0 + } + return len(bs), nil + } + + n64, err := safecopy.ZeroOut(dst.start, uintptr(dst.length)) + return int(n64), err +} + +// Safecopy atomics are no slower than non-safecopy atomics, so use the former +// even when !b.needSafecopy to get consistent alignment checking. + +// SwapUint32 invokes safecopy.SwapUint32 on the first 4 bytes of b. +// +// Preconditions: b.Len() >= 4. +func SwapUint32(b Block, new uint32) (uint32, error) { + if b.length < 4 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.SwapUint32(b.start, new) +} + +// SwapUint64 invokes safecopy.SwapUint64 on the first 8 bytes of b. +// +// Preconditions: b.Len() >= 8. +func SwapUint64(b Block, new uint64) (uint64, error) { + if b.length < 8 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.SwapUint64(b.start, new) +} + +// CompareAndSwapUint32 invokes safecopy.CompareAndSwapUint32 on the first 4 +// bytes of b. +// +// Preconditions: b.Len() >= 4. +func CompareAndSwapUint32(b Block, old, new uint32) (uint32, error) { + if b.length < 4 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.CompareAndSwapUint32(b.start, old, new) +} + +// LoadUint32 invokes safecopy.LoadUint32 on the first 4 bytes of b. +// +// Preconditions: b.Len() >= 4. +func LoadUint32(b Block) (uint32, error) { + if b.length < 4 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.LoadUint32(b.start) +} diff --git a/pkg/safemem/io.go b/pkg/safemem/io.go new file mode 100644 index 000000000..f039a5c34 --- /dev/null +++ b/pkg/safemem/io.go @@ -0,0 +1,392 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "errors" + "io" + "math" +) + +// ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write +// beyond the end of the BlockSeq. +var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq") + +// Reader represents a streaming byte source like io.Reader. +type Reader interface { + // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the + // number of bytes read. It may return a partial read without an error + // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a + // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil); + // note that this differs from io.Reader.Read (in particular, io.EOF should + // not be returned if ReadToBlocks successfully reads dsts.NumBytes() + // bytes.) + ReadToBlocks(dsts BlockSeq) (uint64, error) +} + +// Writer represents a streaming byte sink like io.Writer. +type Writer interface { + // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns + // the number of bytes written. It may return a partial write without an + // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not + // return a full write with an error (i.e. srcs.NumBytes(), err) where err + // != nil). + WriteFromBlocks(srcs BlockSeq) (uint64, error) +} + +// ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes() +// bytes have been read or ReadToBlocks returns an error. +func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() { + n, err := r.ReadToBlocks(dsts) + done += n + if err != nil { + return done, err + } + dsts = dsts.DropFirst64(n) + } + return done, nil +} + +// WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until +// srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error. +func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) { + var done uint64 + for !srcs.IsEmpty() { + n, err := w.WriteFromBlocks(srcs) + done += n + if err != nil { + return done, err + } + srcs = srcs.DropFirst64(n) + } + return done, nil +} + +// BlockSeqReader implements Reader by reading from a BlockSeq. +type BlockSeqReader struct { + Blocks BlockSeq +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { + n, err := CopySeq(dsts, r.Blocks) + r.Blocks = r.Blocks.DropFirst64(n) + if err != nil { + return n, err + } + if n < dsts.NumBytes() { + return n, io.EOF + } + return n, nil +} + +// BlockSeqWriter implements Writer by writing to a BlockSeq. +type BlockSeqWriter struct { + Blocks BlockSeq +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + n, err := CopySeq(w.Blocks, srcs) + w.Blocks = w.Blocks.DropFirst64(n) + if err != nil { + return n, err + } + if n < srcs.NumBytes() { + return n, ErrEndOfBlockSeq + } + return n, nil +} + +// ReaderFunc implements Reader for a function with the semantics of +// Reader.ReadToBlocks. +type ReaderFunc func(dsts BlockSeq) (uint64, error) + +// ReadToBlocks implements Reader.ReadToBlocks. +func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { + return f(dsts) +} + +// WriterFunc implements Writer for a function with the semantics of +// Writer.WriteFromBlocks. +type WriterFunc func(srcs BlockSeq) (uint64, error) + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + return f(srcs) +} + +// ToIOReader implements io.Reader for a (safemem.)Reader. +// +// ToIOReader will return a successful partial read iff Reader.ReadToBlocks does +// so. +type ToIOReader struct { + Reader Reader +} + +// Read implements io.Reader.Read. +func (r ToIOReader) Read(dst []byte) (int, error) { + n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst))) + return int(n), err +} + +// ToIOWriter implements io.Writer for a (safemem.)Writer. +type ToIOWriter struct { + Writer Writer +} + +// Write implements io.Writer.Write. +func (w ToIOWriter) Write(src []byte) (int, error) { + // io.Writer does not permit partial writes. + n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src))) + return int(n), err +} + +// FromIOReader implements Reader for an io.Reader by repeatedly invoking +// io.Reader.Read until it returns an error or partial read. This is not +// thread-safe. +// +// FromIOReader will return a successful partial read iff Reader.Read does so. +type FromIOReader struct { + Reader io.Reader +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !dsts.IsEmpty() { + dst := dsts.Head() + var n int + var err error + n, buf, err = r.readToBlock(dst, buf) + done += uint64(n) + if n != dst.Len() { + return done, err + } + dsts = dsts.Tail() + if err != nil { + if dsts.IsEmpty() && err == io.EOF { + return done, nil + } + return done, err + } + } + return done, nil +} + +func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { + // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !dst.NeedSafecopy() { + n, err := r.Reader.Read(dst.ToSlice()) + return n, buf, err + } + if len(buf) < dst.Len() { + buf = make([]byte, dst.Len()) + } + rn, rerr := r.Reader.Read(buf[:dst.Len()]) + wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) + if wberr != nil { + return wbn, buf, wberr + } + return wbn, buf, rerr +} + +// FromIOReaderAt implements Reader for an io.ReaderAt. Does not repeatedly +// invoke io.ReaderAt.ReadAt because ReadAt is more strict than Read. A partial +// read indicates an error. This is not thread-safe. +type FromIOReaderAt struct { + ReaderAt io.ReaderAt + Offset int64 +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !dsts.IsEmpty() { + dst := dsts.Head() + var n int + var err error + n, buf, err = r.readToBlock(dst, buf) + done += uint64(n) + if n != dst.Len() { + return done, err + } + dsts = dsts.Tail() + if err != nil { + if dsts.IsEmpty() && err == io.EOF { + return done, nil + } + return done, err + } + } + return done, nil +} + +func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) { + // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !dst.NeedSafecopy() { + n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset) + r.Offset += int64(n) + return n, buf, err + } + if len(buf) < dst.Len() { + buf = make([]byte, dst.Len()) + } + rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset) + r.Offset += int64(rn) + wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) + if wberr != nil { + return wbn, buf, wberr + } + return wbn, buf, rerr +} + +// FromIOWriter implements Writer for an io.Writer by repeatedly invoking +// io.Writer.Write until it returns an error or partial write. +// +// FromIOWriter will tolerate implementations of io.Writer.Write that return +// partial writes with a nil error in contravention of io.Writer's +// requirements, since Writer is permitted to do so. FromIOWriter will return a +// successful partial write iff Writer.Write does so. +type FromIOWriter struct { + Writer io.Writer +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !srcs.IsEmpty() { + src := srcs.Head() + var n int + var err error + n, buf, err = w.writeFromBlock(src, buf) + done += uint64(n) + if n != src.Len() || err != nil { + return done, err + } + srcs = srcs.Tail() + } + return done, nil +} + +func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) { + // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !src.NeedSafecopy() { + n, err := w.Writer.Write(src.ToSlice()) + return n, buf, err + } + if len(buf) < src.Len() { + buf = make([]byte, src.Len()) + } + bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src) + wn, werr := w.Writer.Write(buf[:bufn]) + if werr != nil { + return wn, buf, werr + } + return wn, buf, buferr +} + +// FromVecReaderFunc implements Reader for a function that reads data into a +// [][]byte and returns the number of bytes read as an int64. +type FromVecReaderFunc struct { + ReadVec func(dsts [][]byte) (int64, error) +} + +// ReadToBlocks implements Reader.ReadToBlocks. +// +// ReadToBlocks calls r.ReadVec at most once. +func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { + if dsts.IsEmpty() { + return 0, nil + } + // Ensure that we don't pass a [][]byte with a total length > MaxInt64. + dsts = dsts.TakeFirst64(uint64(math.MaxInt64)) + dstSlices := make([][]byte, 0, dsts.NumBlocks()) + // Buffer Blocks that require safecopy. + for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() { + dst := tmp.Head() + if dst.NeedSafecopy() { + dstSlices = append(dstSlices, make([]byte, dst.Len())) + } else { + dstSlices = append(dstSlices, dst.ToSlice()) + } + } + rn, rerr := r.ReadVec(dstSlices) + dsts = dsts.TakeFirst64(uint64(rn)) + var done uint64 + var i int + for !dsts.IsEmpty() { + dst := dsts.Head() + if dst.NeedSafecopy() { + n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i])) + done += uint64(n) + if err != nil { + return done, err + } + } else { + done += uint64(dst.Len()) + } + dsts = dsts.Tail() + i++ + } + return done, rerr +} + +// FromVecWriterFunc implements Writer for a function that writes data from a +// [][]byte and returns the number of bytes written. +type FromVecWriterFunc struct { + WriteVec func(srcs [][]byte) (int64, error) +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +// +// WriteFromBlocks calls w.WriteVec at most once. +func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + if srcs.IsEmpty() { + return 0, nil + } + // Ensure that we don't pass a [][]byte with a total length > MaxInt64. + srcs = srcs.TakeFirst64(uint64(math.MaxInt64)) + srcSlices := make([][]byte, 0, srcs.NumBlocks()) + // Buffer Blocks that require safecopy. + var buferr error + for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() { + src := tmp.Head() + if src.NeedSafecopy() { + slice := make([]byte, src.Len()) + n, err := Copy(BlockFromSafeSlice(slice), src) + srcSlices = append(srcSlices, slice[:n]) + if err != nil { + buferr = err + break + } + } else { + srcSlices = append(srcSlices, src.ToSlice()) + } + } + n, err := w.WriteVec(srcSlices) + if err != nil { + return uint64(n), err + } + return uint64(n), buferr +} diff --git a/pkg/safemem/io_test.go b/pkg/safemem/io_test.go new file mode 100644 index 000000000..629741bee --- /dev/null +++ b/pkg/safemem/io_test.go @@ -0,0 +1,199 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "io" + "testing" +) + +func makeBlocks(slices ...[]byte) []Block { + blocks := make([]Block, 0, len(slices)) + for _, s := range slices { + blocks = append(blocks, BlockFromSafeSlice(s)) + } + return blocks +} + +func TestFromIOReaderFullRead(t *testing.T) { + r := FromIOReader{bytes.NewBufferString("foobar")} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +type eofHidingReader struct { + Reader io.Reader +} + +func (r eofHidingReader) Read(dst []byte) (int, error) { + n, err := r.Reader.Read(dst) + if err == io.EOF { + return n, nil + } + return n, err +} + +func TestFromIOReaderPartialRead(t *testing.T) { + r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + // FromIOReader should stop after the eofHidingReader returns (1, nil) + // for a 3-byte read. + if wantN := uint64(4); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +type singleByteReader struct { + Reader io.Reader +} + +func (r singleByteReader) Read(dst []byte) (int, error) { + if len(dst) == 0 { + return r.Reader.Read(dst) + } + return r.Reader.Read(dst[:1]) +} + +func TestSingleByteReader(t *testing.T) { + r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + // FromIOReader should stop after the singleByteReader returns (1, nil) + // for a 3-byte read. + if wantN := uint64(1); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +func TestReadFullToBlocks(t *testing.T) { + r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) + // ReadFullToBlocks should call into FromIOReader => singleByteReader + // repeatedly until dsts is exhausted. + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +func TestFromIOWriterFullWrite(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{&dst} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +type limitedWriter struct { + Writer io.Writer + Done int + Limit int +} + +func (w *limitedWriter) Write(src []byte) (int, error) { + count := len(src) + if count > (w.Limit - w.Done) { + count = w.Limit - w.Done + } + n, err := w.Writer.Write(src[:count]) + w.Done += n + return n, err +} + +func TestFromIOWriterPartialWrite(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{&limitedWriter{&dst, 0, 4}} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + // FromIOWriter should stop after the limitedWriter returns (1, nil) for a + // 3-byte write. + if wantN := uint64(4); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +type singleByteWriter struct { + Writer io.Writer +} + +func (w singleByteWriter) Write(src []byte) (int, error) { + if len(src) == 0 { + return w.Writer.Write(src) + } + return w.Writer.Write(src[:1]) +} + +func TestSingleByteWriter(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{singleByteWriter{&dst}} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + // FromIOWriter should stop after the singleByteWriter returns (1, nil) + // for a 3-byte write. + if wantN := uint64(1); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +func TestWriteFullToBlocks(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{singleByteWriter{&dst}} + n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) + // WriteFullToBlocks should call into FromIOWriter => singleByteWriter + // repeatedly until srcs is exhausted. + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} diff --git a/pkg/safemem/safemem.go b/pkg/safemem/safemem.go new file mode 100644 index 000000000..3e70d33a2 --- /dev/null +++ b/pkg/safemem/safemem.go @@ -0,0 +1,16 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safemem provides the Block and BlockSeq types. +package safemem diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go new file mode 100644 index 000000000..eba4bb535 --- /dev/null +++ b/pkg/safemem/seq_test.go @@ -0,0 +1,196 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "reflect" + "testing" +) + +type blockSeqTest struct { + desc string + + pieces []string + haveOffset bool + offset uint64 + haveLimit bool + limit uint64 + + want string +} + +func (t blockSeqTest) NonEmptyByteSlices() [][]byte { + // t is a value, so we can mutate it freely. + slices := make([][]byte, 0, len(t.pieces)) + for _, str := range t.pieces { + if t.haveOffset { + strOff := t.offset + if strOff > uint64(len(str)) { + strOff = uint64(len(str)) + } + str = str[strOff:] + t.offset -= strOff + } + if t.haveLimit { + strLim := t.limit + if strLim > uint64(len(str)) { + strLim = uint64(len(str)) + } + str = str[:strLim] + t.limit -= strLim + } + if len(str) != 0 { + slices = append(slices, []byte(str)) + } + } + return slices +} + +func (t blockSeqTest) BlockSeq() BlockSeq { + blocks := make([]Block, 0, len(t.pieces)) + for _, str := range t.pieces { + blocks = append(blocks, BlockFromSafeSlice([]byte(str))) + } + bs := BlockSeqFromSlice(blocks) + if t.haveOffset { + bs = bs.DropFirst64(t.offset) + } + if t.haveLimit { + bs = bs.TakeFirst64(t.limit) + } + return bs +} + +var blockSeqTests = []blockSeqTest{ + { + desc: "Empty sequence", + }, + { + desc: "Sequence of length 1", + pieces: []string{"foobar"}, + want: "foobar", + }, + { + desc: "Sequence of length 2", + pieces: []string{"foo", "bar"}, + want: "foobar", + }, + { + desc: "Empty Blocks", + pieces: []string{"", "foo", "", "", "bar", ""}, + want: "foobar", + }, + { + desc: "Sequence with non-zero offset", + pieces: []string{"foo", "bar"}, + haveOffset: true, + offset: 2, + want: "obar", + }, + { + desc: "Sequence with non-maximal limit", + pieces: []string{"foo", "bar"}, + haveLimit: true, + limit: 5, + want: "fooba", + }, + { + desc: "Sequence with offset and limit", + pieces: []string{"foo", "bar"}, + haveOffset: true, + offset: 2, + haveLimit: true, + limit: 3, + want: "oba", + }, +} + +func TestBlockSeqNumBytes(t *testing.T) { + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { + t.Errorf("NumBytes: got %d, wanted %d", got, want) + } + }) + } +} + +func TestBlockSeqIterBlocks(t *testing.T) { + // Tests BlockSeq iteration using Head/Tail. + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + srcs := test.BlockSeq() + // "Note that a non-nil empty slice and a nil slice ... are not + // deeply equal." - reflect + slices := make([][]byte, 0, 0) + for !srcs.IsEmpty() { + src := srcs.Head() + slices = append(slices, src.ToSlice()) + nextSrcs := srcs.Tail() + if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { + t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) + } + srcs = nextSrcs + } + if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { + t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) + } + }) + } +} + +func TestBlockSeqIterBytes(t *testing.T) { + // Tests BlockSeq iteration using Head/DropFirst. + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + srcs := test.BlockSeq() + var dst bytes.Buffer + for !srcs.IsEmpty() { + src := srcs.Head() + var b [1]byte + n, err := Copy(BlockFromSafeSlice(b[:]), src) + if n != 1 || err != nil { + t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) + } + dst.WriteByte(b[0]) + nextSrcs := srcs.DropFirst(1) + if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { + t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) + } + srcs = nextSrcs + } + if got := string(dst.Bytes()); got != test.want { + t.Errorf("Copied string: got %q, wanted %q", got, test.want) + } + }) + } +} + +func TestBlockSeqDropBeyondLimit(t *testing.T) { + blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} + bs := BlockSeqFromSlice(blocks) + if got, want := bs.NumBytes(), uint64(4); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } + bs = bs.TakeFirst(1) + if got, want := bs.NumBytes(), uint64(1); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } + bs = bs.DropFirst(2) + if got, want := bs.NumBytes(), uint64(0); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } +} diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go new file mode 100644 index 000000000..354a95dde --- /dev/null +++ b/pkg/safemem/seq_unsafe.go @@ -0,0 +1,299 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "fmt" + "reflect" + "unsafe" +) + +// A BlockSeq represents a sequence of Blocks, each of which has non-zero +// length. +// +// BlockSeqs are immutable and may be copied by value. The zero value of +// BlockSeq represents an empty sequence. +type BlockSeq struct { + // If length is 0, then the BlockSeq is empty. Invariants: data == 0; + // offset == 0; limit == 0. + // + // If length is -1, then the BlockSeq represents the single Block{data, + // limit, false}. Invariants: offset == 0; limit > 0; limit does not + // overflow the range of an int. + // + // If length is -2, then the BlockSeq represents the single Block{data, + // limit, true}. Invariants: offset == 0; limit > 0; limit does not + // overflow the range of an int. + // + // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks + // in the array of Blocks starting at address `data`, starting at `offset` + // bytes into the first Block and limited to the following `limit` bytes. + // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <= + // the combined length of all Blocks in the array; the first Block in the + // array has non-zero length. + // + // length is never 1; sequences consisting of a single Block are always + // stored inline (with length < 0). + data unsafe.Pointer + length int + offset int + limit uint64 +} + +// BlockSeqOf returns a BlockSeq representing the single Block b. +func BlockSeqOf(b Block) BlockSeq { + bs := BlockSeq{ + data: b.start, + length: -1, + limit: uint64(b.length), + } + if b.needSafecopy { + bs.length = -2 + } + return bs +} + +// BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice. +// If slice contains Blocks with zero length, BlockSeq will skip them during +// iteration. +// +// Whether the returned BlockSeq shares memory with slice is unspecified; +// clients should avoid mutating slices passed to BlockSeqFromSlice. +// +// Preconditions: The combined length of all Blocks in slice <= math.MaxUint64. +func BlockSeqFromSlice(slice []Block) BlockSeq { + slice = skipEmpty(slice) + var limit uint64 + for _, b := range slice { + sum := limit + uint64(b.Len()) + if sum < limit { + panic("BlockSeq length overflows uint64") + } + limit = sum + } + return blockSeqFromSliceLimited(slice, limit) +} + +// Preconditions: The combined length of all Blocks in slice <= limit. If +// len(slice) != 0, the first Block in slice has non-zero length, and limit > +// 0. +func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { + switch len(slice) { + case 0: + return BlockSeq{} + case 1: + return BlockSeqOf(slice[0].TakeFirst64(limit)) + default: + return BlockSeq{ + data: unsafe.Pointer(&slice[0]), + length: len(slice), + limit: limit, + } + } +} + +func skipEmpty(slice []Block) []Block { + for i, b := range slice { + if b.Len() != 0 { + return slice[i:] + } + } + return nil +} + +// IsEmpty returns true if bs contains no Blocks. +// +// Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0). +// (Of these, prefer to use bs.IsEmpty().) +func (bs BlockSeq) IsEmpty() bool { + return bs.length == 0 +} + +// NumBlocks returns the number of Blocks in bs. +func (bs BlockSeq) NumBlocks() int { + // In general, we have to count: if bs represents a windowed slice then the + // slice may contain Blocks with zero length, and bs.length may be larger + // than the actual number of Blocks due to bs.limit. + var n int + for !bs.IsEmpty() { + n++ + bs = bs.Tail() + } + return n +} + +// NumBytes returns the sum of Block.Len() for all Blocks in bs. +func (bs BlockSeq) NumBytes() uint64 { + return bs.limit +} + +// Head returns the first Block in bs. +// +// Preconditions: !bs.IsEmpty(). +func (bs BlockSeq) Head() Block { + if bs.length == 0 { + panic("empty BlockSeq") + } + if bs.length < 0 { + return bs.internalBlock() + } + return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit) +} + +// Preconditions: bs.length < 0. +func (bs BlockSeq) internalBlock() Block { + return Block{ + start: bs.data, + length: int(bs.limit), + needSafecopy: bs.length == -2, + } +} + +// Tail returns a BlockSeq consisting of all Blocks in bs after the first. +// +// Preconditions: !bs.IsEmpty(). +func (bs BlockSeq) Tail() BlockSeq { + if bs.length == 0 { + panic("empty BlockSeq") + } + if bs.length < 0 { + return BlockSeq{} + } + head := (*Block)(bs.data).DropFirst(bs.offset) + headLen := uint64(head.Len()) + if headLen >= bs.limit { + // The head Block exhausts the limit, so the tail is empty. + return BlockSeq{} + } + var extSlice []Block + extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice)) + extSliceHdr.Data = uintptr(bs.data) + extSliceHdr.Len = bs.length + extSliceHdr.Cap = bs.length + tailSlice := skipEmpty(extSlice[1:]) + tailLimit := bs.limit - headLen + return blockSeqFromSliceLimited(tailSlice, tailLimit) +} + +// DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes +// omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq. +// +// Preconditions: n >= 0. +func (bs BlockSeq) DropFirst(n int) BlockSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return bs.DropFirst64(uint64(n)) +} + +// DropFirst64 is equivalent to DropFirst but takes an uint64. +func (bs BlockSeq) DropFirst64(n uint64) BlockSeq { + if n >= bs.limit { + return BlockSeq{} + } + for { + // Calling bs.Head() here is surprisingly expensive, so inline getting + // the head's length. + var headLen uint64 + if bs.length < 0 { + headLen = bs.limit + } else { + headLen = uint64((*Block)(bs.data).Len() - bs.offset) + } + if n < headLen { + // Dropping ends partway through the head Block. + if bs.length < 0 { + return BlockSeqOf(bs.internalBlock().DropFirst64(n)) + } + bs.offset += int(n) + bs.limit -= n + return bs + } + n -= headLen + bs = bs.Tail() + } +} + +// TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n > +// bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs. +// +// Preconditions: n >= 0. +func (bs BlockSeq) TakeFirst(n int) BlockSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return bs.TakeFirst64(uint64(n)) +} + +// TakeFirst64 is equivalent to TakeFirst but takes a uint64. +func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq { + if n == 0 { + return BlockSeq{} + } + if bs.limit > n { + bs.limit = n + } + return bs +} + +// String implements fmt.Stringer.String. +func (bs BlockSeq) String() string { + var buf bytes.Buffer + buf.WriteByte('[') + var sep string + for !bs.IsEmpty() { + buf.WriteString(sep) + sep = " " + buf.WriteString(bs.Head().String()) + bs = bs.Tail() + } + buf.WriteByte(']') + return buf.String() +} + +// CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less, +// from srcs to dsts and returns the number of bytes copied. +// +// If srcs and dsts overlap, the data stored in dsts is unspecified. +func CopySeq(dsts, srcs BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() && !srcs.IsEmpty() { + dst := dsts.Head() + src := srcs.Head() + n, err := Copy(dst, src) + done += uint64(n) + if err != nil { + return done, err + } + dsts = dsts.DropFirst(n) + srcs = srcs.DropFirst(n) + } + return done, nil +} + +// ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed. +func ZeroSeq(dsts BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() { + n, err := Zero(dsts.Head()) + done += uint64(n) + if err != nil { + return done, err + } + dsts = dsts.DropFirst(n) + } + return done, nil +} diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 51ca09b24..34c0a867d 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -30,13 +30,13 @@ go_library( ":registers_go_proto", "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/cpuid", "//pkg/log", - "//pkg/sentry/context", "//pkg/sentry/limits", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 81ec98a77..1d11cc472 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Arch describes an architecture. diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index ea4dedbdf..3b6987665 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 2aa08b1a9..85d6acc0f 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -25,7 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Host specifies the host architecture. diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 0d5b7d317..94f1a808f 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Host specifies the host architecture. diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 84f11b0d1..d388ee9cf 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -21,7 +21,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // ErrFloatingPoint indicates a failed restore due to unusable floating point diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 9f41e566f..a18093155 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -25,9 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // System-related constants for x86. diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go index 4546b2ef9..2b4c8f3fc 100644 --- a/pkg/sentry/arch/auxv.go +++ b/pkg/sentry/arch/auxv.go @@ -15,7 +15,7 @@ package arch import ( - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // An AuxEntry represents an entry in an ELF auxiliary vector. diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 402e46025..8b03d0187 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -16,7 +16,7 @@ package arch import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // SignalAct represents the action that should be taken when a signal is diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 1e4f9c3c2..81b92bb43 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -23,7 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 7d0e98935..4f4cc46a8 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index d324da705..1a6056171 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -17,7 +17,7 @@ package arch import ( - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 7472c3c61..09bceabc9 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -18,8 +18,8 @@ import ( "encoding/binary" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/usermem" ) // Stack is a simple wrapper around a usermem.IO and an address. diff --git a/pkg/sentry/context/BUILD b/pkg/sentry/context/BUILD deleted file mode 100644 index e13a9ce20..000000000 --- a/pkg/sentry/context/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - srcs = ["context.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/amutex", - "//pkg/log", - ], -) diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go deleted file mode 100644 index 23e009ef3..000000000 --- a/pkg/sentry/context/context.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context defines an internal context type. -// -// The given Context conforms to the standard Go context, but mandates -// additional methods that are specific to the kernel internals. Note however, -// that the Context described by this package carries additional constraints -// regarding concurrent access and retaining beyond the scope of a call. -// -// See the Context type for complete details. -package context - -import ( - "context" - "time" - - "gvisor.dev/gvisor/pkg/amutex" - "gvisor.dev/gvisor/pkg/log" -) - -type contextID int - -// Globally accessible values from a context. These keys are defined in the -// context package to resolve dependency cycles by not requiring the caller to -// import packages usually required to get these information. -const ( - // CtxThreadGroupID is the current thread group ID when a context represents - // a task context. The value is represented as an int32. - CtxThreadGroupID contextID = iota -) - -// ThreadGroupIDFromContext returns the current thread group ID when ctx -// represents a task context. -func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { - if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { - return tgid.(int32), true - } - return 0, false -} - -// A Context represents a thread of execution (hereafter "goroutine" to reflect -// Go idiosyncrasy). It carries state associated with the goroutine across API -// boundaries. -// -// While Context exists for essentially the same reasons as Go's standard -// context.Context, the standard type represents the state of an operation -// rather than that of a goroutine. This is a critical distinction: -// -// - Unlike context.Context, which "may be passed to functions running in -// different goroutines", it is *not safe* to use the same Context in multiple -// concurrent goroutines. -// -// - It is *not safe* to retain a Context passed to a function beyond the scope -// of that function call. -// -// In both cases, values extracted from the Context should be used instead. -type Context interface { - log.Logger - amutex.Sleeper - context.Context - - // UninterruptibleSleepStart indicates the beginning of an uninterruptible - // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate - // is true and the Context represents a Task, the Task's AddressSpace is - // deactivated. - UninterruptibleSleepStart(deactivate bool) - - // UninterruptibleSleepFinish indicates the end of an uninterruptible sleep - // state that was begun by a previous call to UninterruptibleSleepStart. If - // activate is true and the Context represents a Task, the Task's - // AddressSpace is activated. Normally activate is the same value as the - // deactivate parameter passed to UninterruptibleSleepStart. - UninterruptibleSleepFinish(activate bool) -} - -// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep -// methods for anonymous embedding in other types that do not implement sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - -// Deadline returns zero values, meaning no deadline. -func (NoopSleeper) Deadline() (time.Time, bool) { - return time.Time{}, false -} - -// Done returns nil. -func (NoopSleeper) Done() <-chan struct{} { - return nil -} - -// Err returns nil. -func (NoopSleeper) Err() error { - return nil -} - -// logContext implements basic logging. -type logContext struct { - log.Logger - NoopSleeper -} - -// Value implements Context.Value. -func (logContext) Value(key interface{}) interface{} { - return nil -} - -// bgContext is the context returned by context.Background. -var bgContext = &logContext{Logger: log.Log()} - -// Background returns an empty context using the default logger. -// -// Users should be wary of using a Background context. Please tag any use with -// FIXME(b/38173783) and a note to remove this use. -// -// Generally, one should use the Task as their context when available, or avoid -// having to use a context in places where a Task is unavailable. -// -// Using a Background context for tests is fine, as long as no values are -// needed from the context in the tested code paths. -func Background() Context { - return bgContext -} diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD deleted file mode 100644 index f91a6d4ed..000000000 --- a/pkg/sentry/context/contexttest/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/memutil", - "//pkg/sentry/context", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/uniqueid", - ], -) diff --git a/pkg/sentry/context/contexttest/contexttest.go b/pkg/sentry/context/contexttest/contexttest.go deleted file mode 100644 index 15cf086a9..000000000 --- a/pkg/sentry/context/contexttest/contexttest.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest builds a test context.Context. -package contexttest - -import ( - "os" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform. -// -// Note that some filesystems may require a minimal kernel for testing, which -// this test context does not provide. For such tests, see kernel/contexttest. -func Context(tb testing.TB) context.Context { - const memfileName = "contexttest-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - tb.Fatalf("error creating application memory file: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - memfile.Close() - tb.Fatalf("error creating pgalloc.MemoryFile: %v", err) - } - p, err := ptrace.New() - if err != nil { - tb.Fatal(err) - } - // Test usage of context.Background is fine. - return &TestContext{ - Context: context.Background(), - l: limits.NewLimitSet(), - mf: mf, - platform: p, - creds: auth.NewAnonymousCredentials(), - otherValues: make(map[interface{}]interface{}), - } -} - -// TestContext represents a context with minimal functionality suitable for -// running tests. -type TestContext struct { - context.Context - l *limits.LimitSet - mf *pgalloc.MemoryFile - platform platform.Platform - creds *auth.Credentials - otherValues map[interface{}]interface{} -} - -// globalUniqueID tracks incremental unique identifiers for tests. -var globalUniqueID uint64 - -// globalUniqueIDProvider implements unix.UniqueIDProvider. -type globalUniqueIDProvider struct{} - -// UniqueID implements unix.UniqueIDProvider.UniqueID. -func (*globalUniqueIDProvider) UniqueID() uint64 { - return atomic.AddUint64(&globalUniqueID, 1) -} - -// lastInotifyCookie is a monotonically increasing counter for generating unique -// inotify cookies. Must be accessed using atomic ops. -var lastInotifyCookie uint32 - -// hostClock implements ktime.Clock. -type hostClock struct { - ktime.WallRateClock - ktime.NoClockEvents -} - -// Now implements ktime.Clock.Now. -func (hostClock) Now() ktime.Time { - return ktime.FromNanoseconds(time.Now().UnixNano()) -} - -// RegisterValue registers additional values with this test context. Useful for -// providing values from external packages that contexttest can't depend on. -func (t *TestContext) RegisterValue(key, value interface{}) { - t.otherValues[key] = value -} - -// Value implements context.Context. -func (t *TestContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return t.creds - case limits.CtxLimits: - return t.l - case pgalloc.CtxMemoryFile: - return t.mf - case pgalloc.CtxMemoryFileProvider: - return t - case platform.CtxPlatform: - return t.platform - case uniqueid.CtxGlobalUniqueID: - return (*globalUniqueIDProvider).UniqueID(nil) - case uniqueid.CtxGlobalUniqueIDProvider: - return &globalUniqueIDProvider{} - case uniqueid.CtxInotifyCookie: - return atomic.AddUint32(&lastInotifyCookie, 1) - case ktime.CtxRealtimeClock: - return hostClock{} - default: - if val, ok := t.otherValues[key]; ok { - return val - } - return t.Context.Value(key) - } -} - -// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile. -func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { - return t.mf -} - -// RootContext returns a Context that may be used in tests that need root -// credentials. Uses ptrace as the platform.Platform. -func RootContext(tb testing.TB) context.Context { - return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithCreds returns a copy of ctx carrying creds. -func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { - return &authContext{ctx, creds} -} - -type authContext struct { - context.Context - creds *auth.Credentials -} - -// Value implements context.Context. -func (ac *authContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return ac.creds - default: - return ac.Context.Value(key) - } -} - -// WithLimitSet returns a copy of ctx carrying l. -func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context { - return limitContext{ctx, l} -} - -type limitContext struct { - context.Context - l *limits.LimitSet -} - -// Value implements context.Context. -func (lc limitContext) Value(key interface{}) interface{} { - switch key { - case limits.CtxLimits: - return lc.l - default: - return lc.Context.Value(key) - } -} diff --git a/pkg/sentry/contexttest/BUILD b/pkg/sentry/contexttest/BUILD new file mode 100644 index 000000000..6f4c86684 --- /dev/null +++ b/pkg/sentry/contexttest/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "contexttest", + testonly = 1, + srcs = ["contexttest.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/context", + "//pkg/memutil", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", + "//pkg/sentry/limits", + "//pkg/sentry/pgalloc", + "//pkg/sentry/platform", + "//pkg/sentry/platform/ptrace", + "//pkg/sentry/uniqueid", + ], +) diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go new file mode 100644 index 000000000..031fc64ec --- /dev/null +++ b/pkg/sentry/contexttest/contexttest.go @@ -0,0 +1,188 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package contexttest builds a test context.Context. +package contexttest + +import ( + "os" + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" +) + +// Context returns a Context that may be used in tests. Uses ptrace as the +// platform.Platform. +// +// Note that some filesystems may require a minimal kernel for testing, which +// this test context does not provide. For such tests, see kernel/contexttest. +func Context(tb testing.TB) context.Context { + const memfileName = "contexttest-memory" + memfd, err := memutil.CreateMemFD(memfileName, 0) + if err != nil { + tb.Fatalf("error creating application memory file: %v", err) + } + memfile := os.NewFile(uintptr(memfd), memfileName) + mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) + if err != nil { + memfile.Close() + tb.Fatalf("error creating pgalloc.MemoryFile: %v", err) + } + p, err := ptrace.New() + if err != nil { + tb.Fatal(err) + } + // Test usage of context.Background is fine. + return &TestContext{ + Context: context.Background(), + l: limits.NewLimitSet(), + mf: mf, + platform: p, + creds: auth.NewAnonymousCredentials(), + otherValues: make(map[interface{}]interface{}), + } +} + +// TestContext represents a context with minimal functionality suitable for +// running tests. +type TestContext struct { + context.Context + l *limits.LimitSet + mf *pgalloc.MemoryFile + platform platform.Platform + creds *auth.Credentials + otherValues map[interface{}]interface{} +} + +// globalUniqueID tracks incremental unique identifiers for tests. +var globalUniqueID uint64 + +// globalUniqueIDProvider implements unix.UniqueIDProvider. +type globalUniqueIDProvider struct{} + +// UniqueID implements unix.UniqueIDProvider.UniqueID. +func (*globalUniqueIDProvider) UniqueID() uint64 { + return atomic.AddUint64(&globalUniqueID, 1) +} + +// lastInotifyCookie is a monotonically increasing counter for generating unique +// inotify cookies. Must be accessed using atomic ops. +var lastInotifyCookie uint32 + +// hostClock implements ktime.Clock. +type hostClock struct { + ktime.WallRateClock + ktime.NoClockEvents +} + +// Now implements ktime.Clock.Now. +func (hostClock) Now() ktime.Time { + return ktime.FromNanoseconds(time.Now().UnixNano()) +} + +// RegisterValue registers additional values with this test context. Useful for +// providing values from external packages that contexttest can't depend on. +func (t *TestContext) RegisterValue(key, value interface{}) { + t.otherValues[key] = value +} + +// Value implements context.Context. +func (t *TestContext) Value(key interface{}) interface{} { + switch key { + case auth.CtxCredentials: + return t.creds + case limits.CtxLimits: + return t.l + case pgalloc.CtxMemoryFile: + return t.mf + case pgalloc.CtxMemoryFileProvider: + return t + case platform.CtxPlatform: + return t.platform + case uniqueid.CtxGlobalUniqueID: + return (*globalUniqueIDProvider).UniqueID(nil) + case uniqueid.CtxGlobalUniqueIDProvider: + return &globalUniqueIDProvider{} + case uniqueid.CtxInotifyCookie: + return atomic.AddUint32(&lastInotifyCookie, 1) + case ktime.CtxRealtimeClock: + return hostClock{} + default: + if val, ok := t.otherValues[key]; ok { + return val + } + return t.Context.Value(key) + } +} + +// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile. +func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { + return t.mf +} + +// RootContext returns a Context that may be used in tests that need root +// credentials. Uses ptrace as the platform.Platform. +func RootContext(tb testing.TB) context.Context { + return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) +} + +// WithCreds returns a copy of ctx carrying creds. +func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { + return &authContext{ctx, creds} +} + +type authContext struct { + context.Context + creds *auth.Credentials +} + +// Value implements context.Context. +func (ac *authContext) Value(key interface{}) interface{} { + switch key { + case auth.CtxCredentials: + return ac.creds + default: + return ac.Context.Value(key) + } +} + +// WithLimitSet returns a copy of ctx carrying l. +func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context { + return limitContext{ctx, l} +} + +type limitContext struct { + context.Context + l *limits.LimitSet +} + +// Value implements context.Context. +func (lc limitContext) Value(key interface{}) interface{} { + switch key { + case limits.CtxLimits: + return lc.l + default: + return lc.Context.Value(key) + } +} diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 605d61dbe..ea85ab33c 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -47,13 +47,13 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", + "//pkg/context", "//pkg/log", "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/secio", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", @@ -64,10 +64,10 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/state", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -107,14 +107,14 @@ go_test( ], deps = [ ":fs", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -129,7 +129,7 @@ go_test( ], library = ":fs", deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", ], ) diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD index c14e5405e..aedcecfa1 100644 --- a/pkg/sentry/fs/anon/BUILD +++ b/pkg/sentry/fs/anon/BUILD @@ -11,10 +11,10 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go index 7323c7222..5c421f5fb 100644 --- a/pkg/sentry/fs/anon/anon.go +++ b/pkg/sentry/fs/anon/anon.go @@ -18,10 +18,10 @@ package anon import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // NewInode constructs an anonymous Inode that is not associated diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go index 4f3d6410e..fa9e7d517 100644 --- a/pkg/sentry/fs/attr.go +++ b/pkg/sentry/fs/attr.go @@ -20,8 +20,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go index dd427de5d..0fbd60056 100644 --- a/pkg/sentry/fs/context.go +++ b/pkg/sentry/fs/context.go @@ -16,7 +16,7 @@ package fs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index e03e3e417..f6c79e51b 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -19,12 +19,12 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // copyUp copies a file in an overlay from a lower filesystem to an diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 738580c5f..91792d9fe 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -24,8 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 0c7247bd7..4c4b7d5cc 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -16,8 +16,9 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/rand", - "//pkg/sentry/context", + "//pkg/safemem", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -26,9 +27,8 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index f739c476c..35bd23991 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -18,11 +18,11 @@ package dev import ( "math" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Memory device numbers are from Linux's drivers/char/mem.c diff --git a/pkg/sentry/fs/dev/fs.go b/pkg/sentry/fs/dev/fs.go index 55f8af704..5e518fb63 100644 --- a/pkg/sentry/fs/dev/fs.go +++ b/pkg/sentry/fs/dev/fs.go @@ -15,7 +15,7 @@ package dev import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go index 07e0ea010..deb9c6ad8 100644 --- a/pkg/sentry/fs/dev/full.go +++ b/pkg/sentry/fs/dev/full.go @@ -16,11 +16,11 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/dev/null.go b/pkg/sentry/fs/dev/null.go index 4404b97ef..aec33d0d9 100644 --- a/pkg/sentry/fs/dev/null.go +++ b/pkg/sentry/fs/dev/null.go @@ -16,7 +16,7 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" diff --git a/pkg/sentry/fs/dev/random.go b/pkg/sentry/fs/dev/random.go index 49cb92f6e..2a9bbeb18 100644 --- a/pkg/sentry/fs/dev/random.go +++ b/pkg/sentry/fs/dev/random.go @@ -16,12 +16,12 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/dev/tty.go b/pkg/sentry/fs/dev/tty.go index 87d80e292..760ca563d 100644 --- a/pkg/sentry/fs/dev/tty.go +++ b/pkg/sentry/fs/dev/tty.go @@ -16,7 +16,7 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 31fc4d87b..acab0411a 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -22,8 +22,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 47bc72a88..98d69c6f2 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -18,8 +18,8 @@ import ( "syscall" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" ) func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode { diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 25ef96299..1d09e983c 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -12,17 +12,17 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", "//pkg/log", + "//pkg/safemem", "//pkg/secio", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -36,13 +36,13 @@ go_test( ], library = ":fdpipe", deps = [ + "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", "@com_github_google_uuid//:go_default_library", ], ) diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 5b6cfeb0a..9fce177ad 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -19,17 +19,17 @@ import ( "os" "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/secio" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go index 64b558975..0c3595998 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener.go @@ -20,8 +20,8 @@ import ( "syscall" "time" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index 577445148..e556da48a 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -26,12 +26,12 @@ import ( "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) type hostOpener struct { diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index cee87f726..af8230a7d 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -18,7 +18,7 @@ import ( "fmt" "io/ioutil" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index 69abc1e71..5aff0cc95 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -23,10 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) func singlePipeFD() (int, error) { diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 7c4586296..ca3466f4f 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -20,16 +20,16 @@ import ( "time" "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index b88303f17..beba0f771 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -17,10 +17,10 @@ package fs import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 8991207b4..dcc1df38f 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -17,13 +17,13 @@ package fs import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go index 2fb824d5c..02538bb4f 100644 --- a/pkg/sentry/fs/file_overlay_test.go +++ b/pkg/sentry/fs/file_overlay_test.go @@ -18,7 +18,7 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index c5b51620a..084da2a8d 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -19,7 +19,7 @@ import ( "sort" "strings" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index 9a7608cae..a8000e010 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -8,12 +8,12 @@ go_library( srcs = ["filetest.go"], visibility = ["//pkg/sentry:internal"], deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 22270a494..8049538f2 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -19,12 +19,12 @@ import ( "fmt" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index 26abf49e2..bdba6efe5 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -54,8 +54,8 @@ package fs import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 9142f5bdf..4ab2a384f 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -77,22 +77,22 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/state", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -106,13 +106,13 @@ go_test( ], library = ":fsutil", deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/safemem", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index 12132680b..c6cd45087 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -17,11 +17,11 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // DirtySet maps offsets into a memmap.Mappable to DirtyInfo. It is used to diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go index 75575d994..e3579c23c 100644 --- a/pkg/sentry/fs/fsutil/dirty_set_test.go +++ b/pkg/sentry/fs/fsutil/dirty_set_test.go @@ -19,7 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func TestDirtySet(t *testing.T) { diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index fc5b3b1a1..08695391c 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -17,12 +17,12 @@ package fsutil import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index f52d712e3..5643cdac9 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -19,13 +19,13 @@ import ( "io" "math" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 837fc70b5..67278aa86 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -19,11 +19,11 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go index ad11a0573..2d4778d64 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go @@ -17,7 +17,7 @@ package fsutil import ( "unsafe" - "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/safemem" ) func (*HostFileMapper) unsafeBlockFromChunkMapping(addr uintptr) safemem.Block { diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index a625f0e26..78fec553e 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -17,13 +17,13 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // HostMappable implements memmap.Mappable and platform.File over a diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index df7b74855..252830572 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -16,7 +16,7 @@ package fsutil import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 20a014402..573b8586e 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -18,18 +18,18 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // Lock order (compare the lock order model in mm/mm.go): diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index 129f314c8..1547584c5 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -19,14 +19,14 @@ import ( "io" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) type noopBackingFile struct{} diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index cf48e7c03..971d3718e 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -24,13 +24,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fd", "//pkg/log", "//pkg/metric", "//pkg/p9", "//pkg/refs", + "//pkg/safemem", "//pkg/secio", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fdpipe", @@ -39,13 +40,12 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", - "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -56,10 +56,10 @@ go_test( srcs = ["gofer_test.go"], library = ":gofer", deps = [ + "//pkg/context", "//pkg/p9", "//pkg/p9/p9test", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", ], ) diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index 4848e2374..71cccdc34 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -17,12 +17,12 @@ package gofer import ( "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go index cc11c6339..ebea03c42 100644 --- a/pkg/sentry/fs/gofer/cache_policy.go +++ b/pkg/sentry/fs/gofer/cache_policy.go @@ -17,7 +17,7 @@ package gofer import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go index 2125dafef..3da818aed 100644 --- a/pkg/sentry/fs/gofer/context_file.go +++ b/pkg/sentry/fs/gofer/context_file.go @@ -15,9 +15,9 @@ package gofer import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" ) // contextFile is a wrapper around p9.File that notifies the context that diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 7960b9c7b..23296f246 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -19,16 +19,16 @@ import ( "syscall" "time" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index bb8312849..ff96b28ba 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -17,7 +17,7 @@ package gofer import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go index cf96dd9fa..9d41fcbdb 100644 --- a/pkg/sentry/fs/gofer/fs.go +++ b/pkg/sentry/fs/gofer/fs.go @@ -20,8 +20,8 @@ import ( "fmt" "strconv" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go index 7fc3c32ae..0c2f89ae8 100644 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ b/pkg/sentry/fs/gofer/gofer_test.go @@ -20,10 +20,10 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/p9/p9test" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index b86c49b39..9f7c3e89f 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -17,14 +17,14 @@ package gofer import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/secio" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/safemem" ) // handles are the open handles of a gofer file. They are reference counted to diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 98d1a8a48..ac28174d2 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -19,17 +19,17 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index 0b2eedb7c..238f7804c 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -20,8 +20,8 @@ import ( "path/filepath" "strings" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/time" diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index c09f3b71c..0c1be05ef 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -18,9 +18,9 @@ import ( "fmt" "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index edc796ce0..498c4645a 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -17,9 +17,9 @@ package gofer import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go index d045e04ff..0285c5361 100644 --- a/pkg/sentry/fs/gofer/session_state.go +++ b/pkg/sentry/fs/gofer/session_state.go @@ -17,8 +17,8 @@ package gofer import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index a45a8f36c..376cfce2c 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -16,9 +16,9 @@ package gofer import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go index 848e6812b..2d8d3a2ea 100644 --- a/pkg/sentry/fs/gofer/util.go +++ b/pkg/sentry/fs/gofer/util.go @@ -17,8 +17,8 @@ package gofer import ( "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index f586f47c1..21003ea45 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -27,13 +27,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", "//pkg/log", "//pkg/refs", + "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -41,18 +42,17 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", - "//pkg/sentry/safemem", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/unet", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -69,17 +69,17 @@ go_test( ], library = ":host", deps = [ + "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index 5532ff5a0..1658979fc 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -17,7 +17,7 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index f6c626f2c..e08f56d04 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -18,17 +18,17 @@ import ( "fmt" "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/secio" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go index 68d2697c0..d3e8e3a36 100644 --- a/pkg/sentry/fs/host/fs.go +++ b/pkg/sentry/fs/host/fs.go @@ -23,8 +23,8 @@ import ( "strconv" "strings" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go index c6852ee30..3111d2df9 100644 --- a/pkg/sentry/fs/host/fs_test.go +++ b/pkg/sentry/fs/host/fs_test.go @@ -23,8 +23,8 @@ import ( "sort" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 873a1c52d..6fa39caab 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -18,14 +18,14 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/secio" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go index b267ec305..299e0e0b0 100644 --- a/pkg/sentry/fs/host/inode_state.go +++ b/pkg/sentry/fs/host/inode_state.go @@ -18,7 +18,7 @@ import ( "fmt" "syscall" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go index 2d959f10d..7221bc825 100644 --- a/pkg/sentry/fs/host/inode_test.go +++ b/pkg/sentry/fs/host/inode_test.go @@ -21,7 +21,7 @@ import ( "syscall" "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index c076d5bdd..06fc2d80a 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,11 +19,11 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/socket/control" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 68b38fd1c..eb4afe520 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -21,13 +21,13 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 753ef8cd6..3f218b4a7 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -16,14 +16,14 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // TTYFileOperations implements fs.FileOperations for a host file descriptor diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go index 88d24d693..d49c3a635 100644 --- a/pkg/sentry/fs/host/wait_test.go +++ b/pkg/sentry/fs/host/wait_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index e4cf5a570..b66c091ab 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -16,10 +16,10 @@ package fs import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go index 13261cb81..70f2eae96 100644 --- a/pkg/sentry/fs/inode_operations.go +++ b/pkg/sentry/fs/inode_operations.go @@ -17,7 +17,7 @@ package fs import ( "errors" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index c477de837..4729b4aac 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -19,8 +19,8 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go index 493d98c36..389c219d6 100644 --- a/pkg/sentry/fs/inode_overlay_test.go +++ b/pkg/sentry/fs/inode_overlay_test.go @@ -17,7 +17,7 @@ package fs_test import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index cc7dd1c92..928c90aa0 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -19,13 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go index 9f70a3e82..686e1b1cd 100644 --- a/pkg/sentry/fs/inotify_event.go +++ b/pkg/sentry/fs/inotify_event.go @@ -18,8 +18,8 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/usermem" ) // inotifyEventBaseSize is the base size of linux's struct inotify_event. This diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go index 7a24c6f1b..1d6ea5736 100644 --- a/pkg/sentry/fs/mock.go +++ b/pkg/sentry/fs/mock.go @@ -15,7 +15,7 @@ package fs import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go index 7a9692800..37bae6810 100644 --- a/pkg/sentry/fs/mount.go +++ b/pkg/sentry/fs/mount.go @@ -19,8 +19,8 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" ) // DirentOperations provide file systems greater control over how long a Dirent diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go index 299712cd7..78e35b1e6 100644 --- a/pkg/sentry/fs/mount_overlay.go +++ b/pkg/sentry/fs/mount_overlay.go @@ -15,7 +15,7 @@ package fs import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // overlayMountSourceOperations implements MountSourceOperations for an overlay diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go index 0b84732aa..e672a438c 100644 --- a/pkg/sentry/fs/mount_test.go +++ b/pkg/sentry/fs/mount_test.go @@ -18,7 +18,7 @@ import ( "fmt" "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" ) // cacheReallyContains iterates through the dirent cache to determine whether diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index a9627a9d1..574a2cc91 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -22,9 +22,9 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go index c4c771f2c..a69b41468 100644 --- a/pkg/sentry/fs/mounts_test.go +++ b/pkg/sentry/fs/mounts_test.go @@ -17,7 +17,7 @@ package fs_test import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go index f7d844ce7..53b5df175 100644 --- a/pkg/sentry/fs/offset.go +++ b/pkg/sentry/fs/offset.go @@ -17,7 +17,7 @@ package fs import ( "math" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // OffsetPageEnd returns the file offset rounded up to the nearest diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index f7702f8f4..a8ae7d81d 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -18,12 +18,12 @@ import ( "fmt" "strings" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // The virtual filesystem implements an overlay configuration. For a high-level diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index b06bead41..280093c5e 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -29,8 +29,8 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/proc/device", @@ -46,10 +46,10 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -64,8 +64,8 @@ go_test( library = ":proc", deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/inet", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go index c4abe319d..7c1d9e7e9 100644 --- a/pkg/sentry/fs/proc/cgroup.go +++ b/pkg/sentry/fs/proc/cgroup.go @@ -17,7 +17,7 @@ package proc import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go index df0c4e3a7..c96533401 100644 --- a/pkg/sentry/fs/proc/cpuinfo.go +++ b/pkg/sentry/fs/proc/cpuinfo.go @@ -17,7 +17,7 @@ package proc import ( "bytes" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" ) diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 9aaeb780b..8fe626e1c 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -20,12 +20,12 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index 2fa3cfa7d..35972e23c 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -19,7 +19,7 @@ import ( "sort" "strconv" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go index 7b3b974ab..0a58ac34c 100644 --- a/pkg/sentry/fs/proc/filesystems.go +++ b/pkg/sentry/fs/proc/filesystems.go @@ -18,7 +18,7 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" ) diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go index 761d24462..daf1ba781 100644 --- a/pkg/sentry/fs/proc/fs.go +++ b/pkg/sentry/fs/proc/fs.go @@ -17,7 +17,7 @@ package proc import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go index 723f6b661..d2859a4c2 100644 --- a/pkg/sentry/fs/proc/inode.go +++ b/pkg/sentry/fs/proc/inode.go @@ -16,14 +16,14 @@ package proc import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go index d7d2afcb7..139d49c34 100644 --- a/pkg/sentry/fs/proc/loadavg.go +++ b/pkg/sentry/fs/proc/loadavg.go @@ -18,7 +18,7 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" ) diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go index 313c6a32b..465b47da9 100644 --- a/pkg/sentry/fs/proc/meminfo.go +++ b/pkg/sentry/fs/proc/meminfo.go @@ -18,11 +18,11 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index d4efc86e0..c10888100 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -20,7 +20,7 @@ import ( "sort" "strings" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/kernel" diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index bad445f3f..6f2775344 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -22,8 +22,8 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" @@ -33,9 +33,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 29867dc3a..c8abb5052 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -20,7 +20,7 @@ import ( "sort" "strconv" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 310d8dd52..21338d912 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -8,14 +8,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/proc/device", "//pkg/sentry/kernel/time", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -26,10 +26,10 @@ go_test( srcs = ["seqfile_test.go"], library = ":seqfile", deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/fs/ramfs", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index f9af191d5..6121f0e95 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -19,14 +19,14 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go index ebfeee835..98e394569 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile_test.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile_test.go @@ -20,11 +20,11 @@ import ( "io" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type seqTest struct { diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go index bc5b2bc7b..d4fbd76ac 100644 --- a/pkg/sentry/fs/proc/stat.go +++ b/pkg/sentry/fs/proc/stat.go @@ -19,7 +19,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/kernel" ) diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index 2bdcf5f70..f8aad2dbd 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -20,13 +20,13 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index b9e8ef35f..0772d4ae4 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -19,14 +19,14 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 6abae7a60..355e83d47 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -17,9 +17,9 @@ package proc import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func TestQuerySendBufferSize(t *testing.T) { diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 7358d6ef9..ca020e11e 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -22,7 +22,7 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" @@ -32,8 +32,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go index 3eacc9265..8d9517b95 100644 --- a/pkg/sentry/fs/proc/uid_gid_map.go +++ b/pkg/sentry/fs/proc/uid_gid_map.go @@ -20,13 +20,13 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go index adfe58adb..c0f6fb802 100644 --- a/pkg/sentry/fs/proc/uptime.go +++ b/pkg/sentry/fs/proc/uptime.go @@ -19,12 +19,12 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go index 27fd5b1cb..35e258ff6 100644 --- a/pkg/sentry/fs/proc/version.go +++ b/pkg/sentry/fs/proc/version.go @@ -17,7 +17,7 @@ package proc import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/kernel" ) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 39c4b84f8..8ca823fb3 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -13,14 +13,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -31,7 +31,7 @@ go_test( srcs = ["tree_test.go"], library = ":ramfs", deps = [ - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", ], ) diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index dcbb8eb2e..bfa304552 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -20,7 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index a24fe2ea2..29ff004f2 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -16,7 +16,7 @@ package ramfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" diff --git a/pkg/sentry/fs/ramfs/symlink.go b/pkg/sentry/fs/ramfs/symlink.go index fcfaa29aa..d988349aa 100644 --- a/pkg/sentry/fs/ramfs/symlink.go +++ b/pkg/sentry/fs/ramfs/symlink.go @@ -16,7 +16,7 @@ package ramfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go index 702cc4a1e..dfc9d3453 100644 --- a/pkg/sentry/fs/ramfs/tree.go +++ b/pkg/sentry/fs/ramfs/tree.go @@ -19,10 +19,10 @@ import ( "path" "strings" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // MakeDirectoryTree constructs a ramfs tree of all directories containing diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go index 61a7e2900..a6ed8b2c5 100644 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ b/pkg/sentry/fs/ramfs/tree_test.go @@ -17,7 +17,7 @@ package ramfs import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index 389c330a0..791d1526c 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,7 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index cc6b3bfbf..f2e8b9932 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -13,12 +13,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/kernel", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/sys/devices.go b/pkg/sentry/fs/sys/devices.go index 4f78ca8d2..b67065956 100644 --- a/pkg/sentry/fs/sys/devices.go +++ b/pkg/sentry/fs/sys/devices.go @@ -18,7 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" diff --git a/pkg/sentry/fs/sys/fs.go b/pkg/sentry/fs/sys/fs.go index e60b63e75..fd03a4e38 100644 --- a/pkg/sentry/fs/sys/fs.go +++ b/pkg/sentry/fs/sys/fs.go @@ -15,7 +15,7 @@ package sys import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" ) diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go index b14bf3f55..0891645e4 100644 --- a/pkg/sentry/fs/sys/sys.go +++ b/pkg/sentry/fs/sys/sys.go @@ -16,10 +16,10 @@ package sys import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func newFile(ctx context.Context, node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode { diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index 092668e8d..d16cdb4df 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -7,13 +7,13 @@ go_library( srcs = ["timerfd.go"], visibility = ["//pkg/sentry:internal"], deps = [ - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel/time", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index f8bf663bb..88c344089 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -19,13 +19,13 @@ package timerfd import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 04776555f..aa7199014 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -14,8 +14,9 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/metric", - "//pkg/sentry/context", + "//pkg/safemem", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -25,12 +26,11 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", - "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -41,10 +41,10 @@ go_test( srcs = ["file_test.go"], library = ":tmpfs", deps = [ - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usage", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/tmpfs/file_regular.go b/pkg/sentry/fs/tmpfs/file_regular.go index 9a6943fe4..614f8f8a1 100644 --- a/pkg/sentry/fs/tmpfs/file_regular.go +++ b/pkg/sentry/fs/tmpfs/file_regular.go @@ -15,11 +15,11 @@ package tmpfs import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go index 0075ef023..aaba35502 100644 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ b/pkg/sentry/fs/tmpfs/file_test.go @@ -18,11 +18,11 @@ import ( "bytes" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func newFileInode(ctx context.Context) *fs.Inode { diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index be98ad751..d5be56c3f 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -19,7 +19,7 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f1c87fe41..dabc10662 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -20,18 +20,18 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/metric" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) var ( diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 0f718e236..c00cef0a5 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -17,7 +17,7 @@ package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) var fsInfo = fs.Info{ diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 29f804c6c..5cb0e0417 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -16,20 +16,20 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -41,7 +41,7 @@ go_test( library = ":tty", deps = [ "//pkg/abi/linux", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/usermem", + "//pkg/sentry/contexttest", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 88aa66b24..108654827 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -21,14 +21,14 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index edee56c12..8fe05ebe5 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -15,7 +15,7 @@ package tty import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 9fe02657e..12b1c6097 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -19,11 +19,11 @@ import ( "unicode/utf8" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 6b07f6bf2..f62da49bd 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -16,13 +16,13 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 21ccc6f32..1ca79c0b2 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -16,12 +16,12 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index 2a51e6bab..db55cdc48 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -16,12 +16,12 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index 917f90cc0..5883f26db 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -16,11 +16,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Terminal is a pseudoterminal. diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go index 59f07ff8e..2cbc05678 100644 --- a/pkg/sentry/fs/tty/tty_test.go +++ b/pkg/sentry/fs/tty/tty_test.go @@ -18,8 +18,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/contexttest" + "gvisor.dev/gvisor/pkg/usermem" ) func TestSimpleMasterToSlave(t *testing.T) { diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index a718920d5..6f78f478f 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -35,21 +35,21 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/fd", "//pkg/fspath", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/safemem", "//pkg/sentry/syscalls/linux", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -73,14 +73,14 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/fspath", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", "//runsc/testutil", "@com_github_google_go-cmp//cmp:go_default_library", "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 12f3990c1..6c5a559fd 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -7,9 +7,9 @@ go_test( size = "small", srcs = ["benchmark_test.go"], deps = [ + "//pkg/context", "//pkg/fspath", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index a56b03711..d1436b943 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -24,9 +24,9 @@ import ( "strings" "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 8944171c8..ebb72b75e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -17,8 +17,8 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/memmap" diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index 4b7d17dc6..373d23b74 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -21,9 +21,9 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 6c14a1e2d..05f992826 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -25,14 +25,14 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/runsc/testutil" ) diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index 841274daf..92f7da40d 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -16,7 +16,7 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 9afb1a84c..07bf58953 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -19,8 +19,8 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index d11153c90..30135ddb0 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -18,13 +18,13 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // regularFile represents a regular file's inode. This too follows the diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index bdf8705c1..1447a4dc1 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -15,11 +15,11 @@ package ext import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // symlink represents a symlink inode. diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 7bf83ccba..e73f1f857 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -29,16 +29,16 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", "//pkg/log", "//pkg/refs", - "//pkg/sentry/context", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -49,13 +49,13 @@ go_test( deps = [ ":kernfs", "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 75624e0b1..373f801ff 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -18,11 +18,11 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // DynamicBytesFile implements kernfs.Inode and represents a read-only diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 5fa1fa67b..6104751c8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -16,11 +16,11 @@ package kernfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index a4600ad47..9d65d0179 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -20,8 +20,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 1700fffd9..adca2313f 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -19,8 +19,8 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 85bcdcc57..79ebea8a5 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -56,8 +56,8 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index fade59491..ee65cf491 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -21,14 +21,14 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const defaultMode linux.FileMode = 01777 diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index f19f12854..0ee7eb9b7 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -16,7 +16,7 @@ package kernfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3768f55b2..12aac2e6a 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -16,8 +16,9 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", + "//pkg/safemem", "//pkg/sentry/fs", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", @@ -26,15 +27,14 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/mm", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/usermem", ], ) @@ -48,15 +48,15 @@ go_test( library = ":proc", deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index f49819187..11477b6a9 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -19,7 +19,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 91eded415..353e37195 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -19,7 +19,7 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index a0580f20d..eb5bc62c0 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -19,7 +19,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 7bc352ae9..efd3b3453 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -20,17 +20,17 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // mm gets the kernel task's MemoryManager. No additional reference is taken on diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 51f634716..e0cb9c47b 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -20,7 +20,7 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index ad3760e39..434998910 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -20,14 +20,14 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) type selfSymlink struct { diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go index 4aaf23e97..608fec017 100644 --- a/pkg/sentry/fsimpl/proc/tasks_net.go +++ b/pkg/sentry/fsimpl/proc/tasks_net.go @@ -22,8 +22,8 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -32,9 +32,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/usermem" ) func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index aabf2bf0c..ad963870b 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -19,7 +19,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index 0a1d3f34b..be54897bb 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -20,7 +20,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/inet" ) diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2c1635f33..6fc3524db 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -22,14 +22,14 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) var ( diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index beda141f1..66c0d8bc8 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -9,7 +9,7 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1305ad01d..e35d52d17 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -20,7 +20,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 12053a5b6..efd5974c4 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -12,10 +12,10 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/cpuid", "//pkg/fspath", "//pkg/memutil", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -27,9 +27,9 @@ go_library( "//pkg/sentry/platform/kvm", "//pkg/sentry/platform/ptrace", "//pkg/sentry/time", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/usermem", "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 295da2d52..89f8c4915 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -21,9 +21,9 @@ import ( "runtime" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go index 2a723a89f..1c98335c1 100644 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ b/pkg/sentry/fsimpl/testutil/testutil.go @@ -24,12 +24,12 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // System represents the context for a single test. diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 857e98bc5..fb436860c 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -30,10 +30,11 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", + "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", @@ -43,12 +44,11 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/sentry/safemem", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -59,10 +59,10 @@ go_test( deps = [ ":tmpfs", "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", "//pkg/refs", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/auth", @@ -82,13 +82,13 @@ go_test( library = ":tmpfs", deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index d88c83499..54241c8e8 100644 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -21,10 +21,10 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 887ca2619..dc0d27cf9 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -16,7 +16,7 @@ package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index d726f03c5..5ee9cf1e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -19,8 +19,8 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 482aabd52..0c57fdca3 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -16,11 +16,11 @@ package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" ) type namedPipe struct { diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 70b42a6ec..5ee7f2a72 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -19,13 +19,13 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const fileName = "mypipe" diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 7c633c1b0..e9e6faf67 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -20,17 +20,17 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) type regularFile struct { diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 034a29fdb..32552e261 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -22,12 +22,12 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" ) // nextFileID is used to generate unique file names. diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 515f033f2..88dbd6e35 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -29,7 +29,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index a145a5ca3..61c78569d 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -12,6 +12,6 @@ go_library( deps = [ "//pkg/fd", "//pkg/log", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go index 19335ca73..506c7864a 100644 --- a/pkg/sentry/hostmm/hostmm.go +++ b/pkg/sentry/hostmm/hostmm.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // NotifyCurrentMemcgPressureCallback requests that f is called whenever the diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index aa621b724..334432abf 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,7 +13,7 @@ go_library( "test_stack.go", ], deps = [ - "//pkg/sentry/context", + "//pkg/context", "//pkg/tcpip/stack", ], ) diff --git a/pkg/sentry/inet/context.go b/pkg/sentry/inet/context.go index 4eda7dd1f..e8cc1bffd 100644 --- a/pkg/sentry/inet/context.go +++ b/pkg/sentry/inet/context.go @@ -15,7 +15,7 @@ package inet import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is the inet package's type for context.Context.Value keys. diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index cebaccd92..0738946d9 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -153,14 +153,15 @@ go_library( "//pkg/binary", "//pkg/bits", "//pkg/bpf", + "//pkg/context", "//pkg/cpuid", "//pkg/eventchannel", "//pkg/log", "//pkg/metric", "//pkg/refs", + "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/lock", @@ -180,7 +181,6 @@ go_library( "//pkg/sentry/mm", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/sentry/safemem", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/time", @@ -188,7 +188,6 @@ go_library( "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", "//pkg/sentry/uniqueid", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", "//pkg/sync", @@ -196,6 +195,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -212,9 +212,9 @@ go_test( library = ":kernel", deps = [ "//pkg/abi", + "//pkg/context", "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/fs/filetest", "//pkg/sentry/kernel/sched", @@ -222,8 +222,8 @@ go_test( "//pkg/sentry/pgalloc", "//pkg/sentry/time", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 64537c9be..2bc49483a 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -61,8 +61,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/bits", + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", "//pkg/sync", "//pkg/syserror", ], diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index 5c0e7d6b6..ef5723127 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -15,7 +15,7 @@ package auth import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is the auth package's type for context.Context.Value keys. diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go index 3d74bc610..28cbe159d 100644 --- a/pkg/sentry/kernel/auth/id_map.go +++ b/pkg/sentry/kernel/auth/id_map.go @@ -16,7 +16,7 @@ package auth import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index 3c9dceaba..0c40bf315 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -17,8 +17,8 @@ package kernel import ( "time" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" ) // contextID is the kernel package's type for context.Context.Value keys. diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index daff608d7..9d26392c0 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -8,8 +8,8 @@ go_library( srcs = ["contexttest.go"], visibility = ["//pkg/sentry:internal"], deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/kernel", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go index 82f9d8922..22c340e56 100644 --- a/pkg/sentry/kernel/contexttest/contexttest.go +++ b/pkg/sentry/kernel/contexttest/contexttest.go @@ -19,8 +19,8 @@ package contexttest import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 19e16ab3a..dedf0fa15 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -24,13 +24,13 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/context", "//pkg/refs", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", "//pkg/sync", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -43,7 +43,7 @@ go_test( ], library = ":epoll", deps = [ - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs/filetest", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index e84742993..8bffb78fc 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -20,13 +20,13 @@ import ( "fmt" "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go index 4a20d4c82..22630e9c5 100644 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ b/pkg/sentry/kernel/epoll/epoll_test.go @@ -17,7 +17,7 @@ package epoll import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index ee2d74864..9983a32e5 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -8,14 +8,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fdnotifier", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -26,8 +26,8 @@ go_test( srcs = ["eventfd_test.go"], library = ":eventfd", deps = [ - "//pkg/sentry/context/contexttest", - "//pkg/sentry/usermem", + "//pkg/sentry/contexttest", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 687690679..87951adeb 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -21,14 +21,14 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go index 018c7f3ef..9b4892f74 100644 --- a/pkg/sentry/kernel/eventfd/eventfd_test.go +++ b/pkg/sentry/kernel/eventfd/eventfd_test.go @@ -17,8 +17,8 @@ package eventfd import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/contexttest" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ad4135b3..9460bb235 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -22,8 +22,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 86164df49..261b815f2 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -18,8 +18,8 @@ import ( "runtime" "testing" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/sentry/limits" diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index f413d8ae2..c5021f2db 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -36,12 +36,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", "//pkg/sentry/memmap", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -51,7 +51,7 @@ go_test( srcs = ["futex_test.go"], library = ":futex", deps = [ - "//pkg/sentry/usermem", "//pkg/sync", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index d1931c8f4..732e66da4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -20,9 +20,9 @@ package futex import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // KeyKind indicates the type of a Key. diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index c23126ca5..7c5c7665b 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -22,8 +22,8 @@ import ( "testing" "unsafe" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // testData implements the Target interface, and allows us to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index c85e97fef..7b90fac5a 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -40,12 +40,12 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/timerfd" "gvisor.dev/gvisor/pkg/sentry/hostcpu" diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 2c7b6206f..4c049d5b4 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -33,16 +33,16 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", + "//pkg/context", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -57,11 +57,11 @@ go_test( ], library = ":pipe", deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 1c0f34269..fe3be5dbd 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -17,7 +17,7 @@ package pipe import ( "io" - "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/sentry/kernel/pipe/buffer_test.go index ee1b90115..4d54b8b8f 100644 --- a/pkg/sentry/kernel/pipe/buffer_test.go +++ b/pkg/sentry/kernel/pipe/buffer_test.go @@ -18,7 +18,7 @@ import ( "testing" "unsafe" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func TestBufferSize(t *testing.T) { diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 716f589af..4b688c627 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -16,7 +16,7 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sync" diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index 16fa80abe..ab75a87ff 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -18,11 +18,11 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) type sleeper struct { diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index e4fd7d420..08410283f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -20,7 +20,7 @@ import ( "sync/atomic" "syscall" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index e3a14b665..bda739dbe 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -18,9 +18,9 @@ import ( "bytes" "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 8394eb78b..80158239e 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,10 +21,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index b4d29fc77..b2b5691ee 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -17,11 +17,11 @@ package pipe import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // ReaderWriter satisfies the FileOperations interface and services both diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 6f83e3cee..a5675bd70 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,12 +16,12 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 3be171cdc..35ad97d5d 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // ptraceOptions are the subset of options controlling a task's ptrace behavior diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index 5514cf432..cef1276ec 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -18,8 +18,8 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // ptraceArch implements arch-specific ptrace commands. diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index 61e412911..d971b96b3 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -17,8 +17,8 @@ package kernel import ( - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // ptraceArch implements arch-specific ptrace commands. diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index b14429854..efebfd872 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Restartable sequences. diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index 2347dcf36..c38c5a40c 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const maxSyscallFilterInstructions = 1 << 15 diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 76e19b551..65e5427c1 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -24,8 +24,8 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", @@ -41,8 +41,8 @@ go_test( library = ":semaphore", deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/kernel/auth", "//pkg/syserror", ], diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 18299814e..1000f3287 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -19,8 +19,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go index c235f6ca4..e47acefdf 100644 --- a/pkg/sentry/kernel/semaphore/semaphore_test.go +++ b/pkg/sentry/kernel/semaphore/semaphore_test.go @@ -18,8 +18,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 5547c5abf..bfd779837 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -11,9 +11,9 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", @@ -22,8 +22,8 @@ go_library( "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 8ddef7eb8..208569057 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -37,9 +37,9 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -47,9 +47,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Key represents a shm segment key. Analogous to a file name. diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 5d44773d4..3eb78e91b 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -9,14 +9,14 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 28be4a939..8243bb93e 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -18,14 +18,14 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index d2d01add4..93c4fe969 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // maxSyscallNum is the highest supported syscall number. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 978d66da8..95adf2778 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -35,8 +35,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 247bd4aba..53d4d211b 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -17,8 +17,8 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // SharingOptions controls what resources are shared by a new task created by diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index bb5560acf..2d6e7733c 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,13 +18,13 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/usermem" ) var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index c211b5b74..a53e77c9f 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -16,7 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Futex returns t's futex manager. diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index 0fb3661de..41259210c 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -20,7 +20,7 @@ import ( "sort" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 6357273d3..5568c91bc 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -26,7 +26,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // A taskRunState is a reified state in the task state machine. See README.md diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 39cd1340d..8802db142 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -26,8 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 58af16ee2..de838beef 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // TaskConfig defines the configuration of a new Task (see below). diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 3180f5560..d555d69a8 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 518bfe1bd..2bf3ce8a8 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,8 +18,8 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // MAX_RW_COUNT is the maximum size in bytes of a single read or write. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index d49594d9f..7ba7dc50c 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -11,7 +11,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sync", "//pkg/syserror", "//pkg/waiter", diff --git a/pkg/sentry/kernel/time/context.go b/pkg/sentry/kernel/time/context.go index 8ef483dd3..00b729d88 100644 --- a/pkg/sentry/kernel/time/context.go +++ b/pkg/sentry/kernel/time/context.go @@ -15,7 +15,7 @@ package time import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is the time package's type for context.Context.Value keys. diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go index 849c5b646..cf2f7ca72 100644 --- a/pkg/sentry/kernel/timekeeper_test.go +++ b/pkg/sentry/kernel/timekeeper_test.go @@ -17,12 +17,12 @@ package kernel import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/pgalloc" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // mockClocks is a sentrytime.Clocks that simply returns the times in the diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index fdd10c56c..f1b3c212c 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -18,10 +18,10 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // vdsoParams are the parameters exposed to the VDSO. diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 67869757f..cf591c4c1 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -12,7 +12,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sync", ], ) diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go index 6972749ed..77e1fe217 100644 --- a/pkg/sentry/limits/context.go +++ b/pkg/sentry/limits/context.go @@ -15,7 +15,7 @@ package limits import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is the limit package's type for context.Context.Value keys. diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index d4ad2bd6c..23790378a 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -24,11 +24,12 @@ go_library( "//pkg/abi", "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/cpuid", "//pkg/log", "//pkg/rand", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", @@ -37,12 +38,11 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", - "//pkg/sentry/safemem", "//pkg/sentry/uniqueid", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 6299a3e2f..122ed05c2 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -23,16 +23,16 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go index ccf909cac..098a45d36 100644 --- a/pkg/sentry/loader/interpreter.go +++ b/pkg/sentry/loader/interpreter.go @@ -18,10 +18,10 @@ import ( "bytes" "io" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index b03eeb005..9a613d6b7 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -24,16 +24,16 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // LoadArgs holds specifications for an executable file to be loaded. diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index df8a81907..52f446ed7 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -20,20 +20,20 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index f9a65f086..a98b66de1 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -38,11 +38,11 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", "//pkg/sentry/platform", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -51,5 +51,5 @@ go_test( size = "small", srcs = ["mapping_set_test.go"], library = ":memmap", - deps = ["//pkg/sentry/usermem"], + deps = ["//pkg/usermem"], ) diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go index 0a5b7ce45..d609c1ae0 100644 --- a/pkg/sentry/memmap/mapping_set.go +++ b/pkg/sentry/memmap/mapping_set.go @@ -18,7 +18,7 @@ import ( "fmt" "math" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // MappingSet maps offsets into a Mappable to mappings of those offsets. It is diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go index f9b11a59c..d39efe38f 100644 --- a/pkg/sentry/memmap/mapping_set_test.go +++ b/pkg/sentry/memmap/mapping_set_test.go @@ -18,7 +18,7 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type testMappingSpace struct { diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 16a722a13..c6db9fc8f 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -18,9 +18,9 @@ package memmap import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index bd6399fa2..e5729ced5 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -27,7 +27,7 @@ go_template_instance( "minDegree": "8", }, imports = { - "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem", + "usermem": "gvisor.dev/gvisor/pkg/usermem", }, package = "mm", prefix = "vma", @@ -47,7 +47,7 @@ go_template_instance( "minDegree": "8", }, imports = { - "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem", + "usermem": "gvisor.dev/gvisor/pkg/usermem", }, package = "mm", prefix = "pma", @@ -99,10 +99,12 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/atomicbitops", + "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/safecopy", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/proc/seqfile", "//pkg/sentry/kernel/auth", @@ -112,13 +114,11 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/safemem", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/buffer", + "//pkg/usermem", ], ) @@ -128,14 +128,14 @@ go_test( srcs = ["mm_test.go"], library = ":mm", deps = [ + "//pkg/context", "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index cfebcfd42..e58a63deb 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // AddressSpace returns the platform.AddressSpace bound to mm. diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4b48866ad..cb29d94b0 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -16,15 +16,15 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // aioManager creates and manages asynchronous I/O contexts. diff --git a/pkg/sentry/mm/debug.go b/pkg/sentry/mm/debug.go index df9adf708..c273c982e 100644 --- a/pkg/sentry/mm/debug.go +++ b/pkg/sentry/mm/debug.go @@ -18,7 +18,7 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) const ( diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go index b03e7d020..fa776f9c6 100644 --- a/pkg/sentry/mm/io.go +++ b/pkg/sentry/mm/io.go @@ -15,11 +15,11 @@ package mm import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // There are two supported ways to copy data to/from application virtual diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 4e9ca1de6..47b8fbf43 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -19,13 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // NewMemoryManager returns a new MemoryManager with no mappings and 1 user. diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index d2a01d48a..f550acae0 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -17,7 +17,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Dumpability describes if and how core dumps should be created. diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 78cc9e6e4..09e582dd3 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -35,14 +35,14 @@ package mm import ( + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // MemoryManager implements a virtual address space. diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index 4d2bfaaed..edacca741 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -17,15 +17,15 @@ package mm import ( "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) func testMemoryManager(ctx context.Context) *MemoryManager { diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index c976c6f45..62e4c20af 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -17,14 +17,14 @@ package mm import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // existingPMAsLocked checks that pmas exist for all addresses in ar, and diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index 79610acb7..1ab92f046 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -19,10 +19,10 @@ import ( "fmt" "strings" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/mm/save_restore.go b/pkg/sentry/mm/save_restore.go index 93259c5a3..f56215d9a 100644 --- a/pkg/sentry/mm/save_restore.go +++ b/pkg/sentry/mm/save_restore.go @@ -17,7 +17,7 @@ package mm import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // InvalidateUnsavable invokes memmap.Mappable.InvalidateUnsavable on all diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go index b9f2d23e5..6432731d4 100644 --- a/pkg/sentry/mm/shm.go +++ b/pkg/sentry/mm/shm.go @@ -15,10 +15,10 @@ package mm import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/shm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // DetachShm unmaps a sysv shared memory segment. diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index ea2d7af74..9ad52082d 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -15,14 +15,14 @@ package mm import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index c2466c988..c5dfa5972 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -19,14 +19,14 @@ import ( mrand "math/rand" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // HandleUserFault handles an application page fault. sp is the faulting diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index f2fd70799..9a14e69e6 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -18,13 +18,13 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Preconditions: mm.mappingMu must be locked for writing. opts must be valid diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 02385a3ce..1eeb9f317 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -61,18 +61,18 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/context", "//pkg/log", "//pkg/memutil", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/hostmm", "//pkg/sentry/platform", - "//pkg/sentry/safemem", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/state", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) @@ -81,5 +81,5 @@ go_test( size = "small", srcs = ["pgalloc_test.go"], library = ":pgalloc", - deps = ["//pkg/sentry/usermem"], + deps = ["//pkg/usermem"], ) diff --git a/pkg/sentry/pgalloc/context.go b/pkg/sentry/pgalloc/context.go index 11ccf897b..d25215418 100644 --- a/pkg/sentry/pgalloc/context.go +++ b/pkg/sentry/pgalloc/context.go @@ -15,7 +15,7 @@ package pgalloc import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is this package's type for context.Context.Value keys. diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index c99e023d9..577e9306a 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -29,15 +29,15 @@ import ( "syscall" "time" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // MemoryFile is a platform.File whose pages may be allocated to arbitrary diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go index 428e6a859..293f22c6b 100644 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ b/pkg/sentry/pgalloc/pgalloc_test.go @@ -17,7 +17,7 @@ package pgalloc import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index aafce1d00..f8385c146 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/usermem" ) // SaveTo writes f's state to the given stream. diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 006450b2d..453241eca 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -26,14 +26,14 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/atomicbitops", + "//pkg/context", "//pkg/log", + "//pkg/safecopy", + "//pkg/safemem", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/safemem", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/context.go b/pkg/sentry/platform/context.go index e29bc4485..6759cda65 100644 --- a/pkg/sentry/platform/context.go +++ b/pkg/sentry/platform/context.go @@ -15,7 +15,7 @@ package platform import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is the auth package's type for context.Context.Value keys. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index a4532a766..159f7eafd 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -44,16 +44,16 @@ go_library( "//pkg/cpuid", "//pkg/log", "//pkg/procid", + "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/ring0", "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/platform/safecopy", "//pkg/sentry/time", - "//pkg/sentry/usermem", "//pkg/sync", + "//pkg/usermem", ], ) @@ -75,6 +75,6 @@ go_test( "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/platform/ring0", "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index a25f3c449..be213bfe8 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // dirtySet tracks vCPUs for invalidation. diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 30dbb74d6..35cd55fef 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,9 +19,9 @@ import ( "reflect" "syscall" + "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" ) // bluepill enters guest mode. diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index f6459cda9..e34f46aeb 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -18,7 +18,7 @@ import ( "sync/atomic" "syscall" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index 99450d22d..c769ac7b4 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // context is an implementation of the platform context. diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index d337c5c7c..972ba85c3 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // KVM represents a lightweight VM context. diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 30df725d4..c42752d50 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -27,7 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) var dummyFPState = (*byte)(arch.NewFloatingPointData()) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index e6d912168..8076c7529 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // machine contains state associated with the VM as a whole. diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 873e39dc7..923ce3909 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // initArchState initializes architecture-specific state. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 3b1f20219..09552837a 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type vCPUArchState struct { diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 3f2f97a6b..1c8384e6b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // setMemoryRegion initializes a region. diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index 91de5dab1..f7fa2f98d 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type region struct { diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go index 2d68855ef..c8897d34f 100644 --- a/pkg/sentry/platform/kvm/virtual_map.go +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -22,7 +22,7 @@ import ( "regexp" "strconv" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type virtualRegion struct { diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go index 6a2f145be..327e2be4f 100644 --- a/pkg/sentry/platform/kvm/virtual_map_test.go +++ b/pkg/sentry/platform/kvm/virtual_map_test.go @@ -18,7 +18,7 @@ import ( "syscall" "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type checker struct { diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go index 999787462..091c2e365 100644 --- a/pkg/sentry/platform/mmap_min_addr.go +++ b/pkg/sentry/platform/mmap_min_addr.go @@ -20,7 +20,7 @@ import ( "strconv" "strings" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // systemMMapMinAddrSource is the source file. diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index ec22dbf87..2ca696382 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,10 +22,10 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // Platform provides abstractions for execution contexts (Context, diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 3bcc5e040..95abd321e 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -25,14 +25,14 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/procid", + "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostcpu", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", "//pkg/sync", + "//pkg/usermem", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index bb0e03880..03adb624b 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -51,8 +51,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) var ( diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 72c7ec564..6c0ed7b3e 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // getRegs gets the general purpose register set. diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go index aa1b87237..341dde143 100644 --- a/pkg/sentry/platform/ptrace/stub_unsafe.go +++ b/pkg/sentry/platform/ptrace/stub_unsafe.go @@ -19,8 +19,8 @@ import ( "syscall" "unsafe" - "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" ) // stub is defined in arch-specific assembly. diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 15dc46a5b..31b7cec53 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -25,8 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" ) // Linux kernel errnos which "should never be seen by user programs", but will diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 6dee8fcc5..934b6fbcd 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -78,6 +78,6 @@ go_library( deps = [ "//pkg/cpuid", "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 9dae0dccb..9c6c2cf5c 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -18,7 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index a850ce6cf..1583dda12 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -18,7 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 147311ed3..4cae10459 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -28,6 +28,6 @@ go_binary( deps = [ "//pkg/cpuid", "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 8b5cdd6c1..971eed7fa 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -93,8 +93,8 @@ go_library( "//pkg/sentry/platform/ring0:__subpackages__", ], deps = [ - "//pkg/sentry/usermem", "//pkg/sync", + "//pkg/usermem", ], ) @@ -108,5 +108,5 @@ go_test( "walker_check.go", ], library = ":pagetables", - deps = ["//pkg/sentry/usermem"], + deps = ["//pkg/usermem"], ) diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go index a90394a33..d08bfdeb3 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go +++ b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go @@ -17,7 +17,7 @@ package pagetables import ( "unsafe" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // newAlignedPTEs returns a set of aligned PTEs. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 30c64a372..87e88e97d 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -21,7 +21,7 @@ package pagetables import ( - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // PageTables is a set of page tables. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index e78424766..78510ebed 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -19,7 +19,7 @@ package pagetables import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // archPageTables is architecture-specific data. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go index 35e917526..54e8e554f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go @@ -19,7 +19,7 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func Test2MAnd4K(t *testing.T) { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go index 254116233..2f73d424f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go @@ -19,7 +19,7 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func Test2MAnd4K(t *testing.T) { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go index 6e95ad2b9..5c88d087d 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go @@ -17,7 +17,7 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type mapping struct { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go index 3e2383c5e..dcf061df9 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go @@ -19,7 +19,7 @@ package pagetables import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // archPageTables is architecture-specific data. diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD deleted file mode 100644 index b8747585b..000000000 --- a/pkg/sentry/platform/safecopy/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safecopy", - srcs = [ - "atomic_amd64.s", - "atomic_arm64.s", - "memclr_amd64.s", - "memclr_arm64.s", - "memcpy_amd64.s", - "memcpy_arm64.s", - "safecopy.go", - "safecopy_unsafe.go", - "sighandler_amd64.s", - "sighandler_arm64.s", - ], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/syserror"], -) - -go_test( - name = "safecopy_test", - srcs = [ - "safecopy_test.go", - ], - library = ":safecopy", -) diff --git a/pkg/sentry/platform/safecopy/LICENSE b/pkg/sentry/platform/safecopy/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sentry/platform/safecopy/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sentry/platform/safecopy/atomic_amd64.s b/pkg/sentry/platform/safecopy/atomic_amd64.s deleted file mode 100644 index a0cd78f33..000000000 --- a/pkg/sentry/platform/safecopy/atomic_amd64.s +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// handleSwapUint32Fault returns the value stored in DI. Control is transferred -// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal -// number stored in DI. -// -// It must have the same frame configuration as swapUint32 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 - MOVL DI, sig+20(FP) - RET - -// swapUint32 atomically stores new into *addr and returns (the previous *addr -// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the -// value of old is unspecified, and sig is the number of the signal that was -// received. -// -// Preconditions: addr must be aligned to a 4-byte boundary. -// -//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) -TEXT ·swapUint32(SB), NOSPLIT, $0-24 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleSwapUint32Fault will store a different value in this address. - MOVL $0, sig+20(FP) - - MOVQ addr+0(FP), DI - MOVL new+8(FP), AX - XCHGL AX, 0(DI) - MOVL AX, old+16(FP) - RET - -// handleSwapUint64Fault returns the value stored in DI. Control is transferred -// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal -// number stored in DI. -// -// It must have the same frame configuration as swapUint64 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 - MOVL DI, sig+24(FP) - RET - -// swapUint64 atomically stores new into *addr and returns (the previous *addr -// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the -// value of old is unspecified, and sig is the number of the signal that was -// received. -// -// Preconditions: addr must be aligned to a 8-byte boundary. -// -//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) -TEXT ·swapUint64(SB), NOSPLIT, $0-28 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleSwapUint64Fault will store a different value in this address. - MOVL $0, sig+24(FP) - - MOVQ addr+0(FP), DI - MOVQ new+8(FP), AX - XCHGQ AX, 0(DI) - MOVQ AX, old+16(FP) - RET - -// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is -// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the -// signal number stored in DI. -// -// It must have the same frame configuration as compareAndSwapUint32 so that it -// can undo any potential call frame set up by the assembler. -TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 - MOVL DI, sig+20(FP) - RET - -// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is -// received during the operation, the value of prev is unspecified, and sig is -// the number of the signal that was received. -// -// Preconditions: addr must be aligned to a 4-byte boundary. -// -//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) -TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 - // Store 0 as the returned signal number. If we run to completion, this is - // the value the caller will see; if a signal is received, - // handleCompareAndSwapUint32Fault will store a different value in this - // address. - MOVL $0, sig+20(FP) - - MOVQ addr+0(FP), DI - MOVL old+8(FP), AX - MOVL new+12(FP), DX - LOCK - CMPXCHGL DX, 0(DI) - MOVL AX, prev+16(FP) - RET - -// handleLoadUint32Fault returns the value stored in DI. Control is transferred -// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal -// number stored in DI. -// -// It must have the same frame configuration as loadUint32 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 - MOVL DI, sig+12(FP) - RET - -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS -// signal is received, the value returned is unspecified, and sig is the number -// of the signal that was received. -// -// Preconditions: addr must be aligned to a 4-byte boundary. -// -//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) -TEXT ·loadUint32(SB), NOSPLIT, $0-16 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleLoadUint32Fault will store a different value in this address. - MOVL $0, sig+12(FP) - - MOVQ addr+0(FP), AX - MOVL (AX), BX - MOVL BX, val+8(FP) - RET diff --git a/pkg/sentry/platform/safecopy/atomic_arm64.s b/pkg/sentry/platform/safecopy/atomic_arm64.s deleted file mode 100644 index d58ed71f7..000000000 --- a/pkg/sentry/platform/safecopy/atomic_arm64.s +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#include "textflag.h" - -// handleSwapUint32Fault returns the value stored in R1. Control is transferred -// to it when swapUint32 below receives SIGSEGV or SIGBUS, with the signal -// number stored in R1. -// -// It must have the same frame configuration as swapUint32 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 - MOVW R1, sig+20(FP) - RET - -// See the corresponding doc in safecopy_unsafe.go -// -// The code is derived from Go source runtime/internal/atomic.Xchg. -// -//func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) -TEXT ·swapUint32(SB), NOSPLIT, $0-24 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleSwapUint32Fault will store a different value in this address. - MOVW $0, sig+20(FP) -again: - MOVD addr+0(FP), R0 - MOVW new+8(FP), R1 - LDAXRW (R0), R2 - STLXRW R1, (R0), R3 - CBNZ R3, again - MOVW R2, old+16(FP) - RET - -// handleSwapUint64Fault returns the value stored in R1. Control is transferred -// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal -// number stored in R1. -// -// It must have the same frame configuration as swapUint64 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 - MOVW R1, sig+24(FP) - RET - -// See the corresponding doc in safecopy_unsafe.go -// -// The code is derived from Go source runtime/internal/atomic.Xchg64. -// -//func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) -TEXT ·swapUint64(SB), NOSPLIT, $0-28 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleSwapUint64Fault will store a different value in this address. - MOVW $0, sig+24(FP) -again: - MOVD addr+0(FP), R0 - MOVD new+8(FP), R1 - LDAXR (R0), R2 - STLXR R1, (R0), R3 - CBNZ R3, again - MOVD R2, old+16(FP) - RET - -// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is -// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS, -// with the signal number stored in R1. -// -// It must have the same frame configuration as compareAndSwapUint32 so that it -// can undo any potential call frame set up by the assembler. -TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 - MOVW R1, sig+20(FP) - RET - -// See the corresponding doc in safecopy_unsafe.go -// -// The code is derived from Go source runtime/internal/atomic.Cas. -// -//func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) -TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 - // Store 0 as the returned signal number. If we run to completion, this is - // the value the caller will see; if a signal is received, - // handleCompareAndSwapUint32Fault will store a different value in this - // address. - MOVW $0, sig+20(FP) - - MOVD addr+0(FP), R0 - MOVW old+8(FP), R1 - MOVW new+12(FP), R2 -again: - LDAXRW (R0), R3 - CMPW R1, R3 - BNE done - STLXRW R2, (R0), R4 - CBNZ R4, again -done: - MOVW R3, prev+16(FP) - RET - -// handleLoadUint32Fault returns the value stored in DI. Control is transferred -// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal -// number stored in DI. -// -// It must have the same frame configuration as loadUint32 so that it can undo -// any potential call frame set up by the assembler. -TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 - MOVW R1, sig+12(FP) - RET - -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS -// signal is received, the value returned is unspecified, and sig is the number -// of the signal that was received. -// -// Preconditions: addr must be aligned to a 4-byte boundary. -// -//func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) -TEXT ·loadUint32(SB), NOSPLIT, $0-16 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleLoadUint32Fault will store a different value in this address. - MOVW $0, sig+12(FP) - - MOVD addr+0(FP), R0 - LDARW (R0), R1 - MOVW R1, val+8(FP) - RET diff --git a/pkg/sentry/platform/safecopy/memclr_amd64.s b/pkg/sentry/platform/safecopy/memclr_amd64.s deleted file mode 100644 index 64cf32f05..000000000 --- a/pkg/sentry/platform/safecopy/memclr_amd64.s +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#include "textflag.h" - -// handleMemclrFault returns (the value stored in AX, the value stored in DI). -// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, -// with the faulting address stored in AX and the signal number stored in DI. -// -// It must have the same frame configuration as memclr so that it can undo any -// potential call frame set up by the assembler. -TEXT handleMemclrFault(SB), NOSPLIT, $0-28 - MOVQ AX, addr+16(FP) - MOVL DI, sig+24(FP) - RET - -// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS -// signal is received during the write, it returns the address that caused the -// fault and the number of the signal that was received. Otherwise, it returns -// an unspecified address and a signal number of 0. -// -// Data is written in order, such that if a fault happens at address p, it is -// safe to assume that all data before p-maxRegisterSize has already been -// successfully written. -// -// The code is derived from runtime.memclrNoHeapPointers. -// -// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) -TEXT ·memclr(SB), NOSPLIT, $0-28 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleMemclrFault will store a different value in this address. - MOVL $0, sig+24(FP) - - MOVQ ptr+0(FP), DI - MOVQ n+8(FP), BX - XORQ AX, AX - - // MOVOU seems always faster than REP STOSQ. -tail: - TESTQ BX, BX - JEQ _0 - CMPQ BX, $2 - JBE _1or2 - CMPQ BX, $4 - JBE _3or4 - CMPQ BX, $8 - JB _5through7 - JE _8 - CMPQ BX, $16 - JBE _9through16 - PXOR X0, X0 - CMPQ BX, $32 - JBE _17through32 - CMPQ BX, $64 - JBE _33through64 - CMPQ BX, $128 - JBE _65through128 - CMPQ BX, $256 - JBE _129through256 - // TODO: use branch table and BSR to make this just a single dispatch - // TODO: for really big clears, use MOVNTDQ, even without AVX2. - -loop: - MOVOU X0, 0(DI) - MOVOU X0, 16(DI) - MOVOU X0, 32(DI) - MOVOU X0, 48(DI) - MOVOU X0, 64(DI) - MOVOU X0, 80(DI) - MOVOU X0, 96(DI) - MOVOU X0, 112(DI) - MOVOU X0, 128(DI) - MOVOU X0, 144(DI) - MOVOU X0, 160(DI) - MOVOU X0, 176(DI) - MOVOU X0, 192(DI) - MOVOU X0, 208(DI) - MOVOU X0, 224(DI) - MOVOU X0, 240(DI) - SUBQ $256, BX - ADDQ $256, DI - CMPQ BX, $256 - JAE loop - JMP tail - -_1or2: - MOVB AX, (DI) - MOVB AX, -1(DI)(BX*1) - RET -_0: - RET -_3or4: - MOVW AX, (DI) - MOVW AX, -2(DI)(BX*1) - RET -_5through7: - MOVL AX, (DI) - MOVL AX, -4(DI)(BX*1) - RET -_8: - // We need a separate case for 8 to make sure we clear pointers atomically. - MOVQ AX, (DI) - RET -_9through16: - MOVQ AX, (DI) - MOVQ AX, -8(DI)(BX*1) - RET -_17through32: - MOVOU X0, (DI) - MOVOU X0, -16(DI)(BX*1) - RET -_33through64: - MOVOU X0, (DI) - MOVOU X0, 16(DI) - MOVOU X0, -32(DI)(BX*1) - MOVOU X0, -16(DI)(BX*1) - RET -_65through128: - MOVOU X0, (DI) - MOVOU X0, 16(DI) - MOVOU X0, 32(DI) - MOVOU X0, 48(DI) - MOVOU X0, -64(DI)(BX*1) - MOVOU X0, -48(DI)(BX*1) - MOVOU X0, -32(DI)(BX*1) - MOVOU X0, -16(DI)(BX*1) - RET -_129through256: - MOVOU X0, (DI) - MOVOU X0, 16(DI) - MOVOU X0, 32(DI) - MOVOU X0, 48(DI) - MOVOU X0, 64(DI) - MOVOU X0, 80(DI) - MOVOU X0, 96(DI) - MOVOU X0, 112(DI) - MOVOU X0, -128(DI)(BX*1) - MOVOU X0, -112(DI)(BX*1) - MOVOU X0, -96(DI)(BX*1) - MOVOU X0, -80(DI)(BX*1) - MOVOU X0, -64(DI)(BX*1) - MOVOU X0, -48(DI)(BX*1) - MOVOU X0, -32(DI)(BX*1) - MOVOU X0, -16(DI)(BX*1) - RET diff --git a/pkg/sentry/platform/safecopy/memclr_arm64.s b/pkg/sentry/platform/safecopy/memclr_arm64.s deleted file mode 100644 index 7361b9067..000000000 --- a/pkg/sentry/platform/safecopy/memclr_arm64.s +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#include "textflag.h" - -// handleMemclrFault returns (the value stored in R0, the value stored in R1). -// Control is transferred to it when memclr below receives SIGSEGV or SIGBUS, -// with the faulting address stored in R0 and the signal number stored in R1. -// -// It must have the same frame configuration as memclr so that it can undo any -// potential call frame set up by the assembler. -TEXT handleMemclrFault(SB), NOSPLIT, $0-28 - MOVD R0, addr+16(FP) - MOVW R1, sig+24(FP) - RET - -// See the corresponding doc in safecopy_unsafe.go -// -// The code is derived from runtime.memclrNoHeapPointers. -// -// func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) -TEXT ·memclr(SB), NOSPLIT, $0-28 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleMemclrFault will store a different value in this address. - MOVW $0, sig+24(FP) - MOVD ptr+0(FP), R0 - MOVD n+8(FP), R1 - - // If size is less than 16 bytes, use tail_zero to zero what remains - CMP $16, R1 - BLT tail_zero - // Get buffer offset into 16 byte aligned address for better performance - ANDS $15, R0, ZR - BNE unaligned_to_16 -aligned_to_16: - LSR $4, R1, R2 -zero_by_16: - STP.P (ZR, ZR), 16(R0) // Store pair with post index. - SUBS $1, R2, R2 - BNE zero_by_16 - ANDS $15, R1, R1 - BEQ end - - // Zero buffer with size=R1 < 16 -tail_zero: - TBZ $3, R1, tail_zero_4 - MOVD.P ZR, 8(R0) -tail_zero_4: - TBZ $2, R1, tail_zero_2 - MOVW.P ZR, 4(R0) -tail_zero_2: - TBZ $1, R1, tail_zero_1 - MOVH.P ZR, 2(R0) -tail_zero_1: - TBZ $0, R1, end - MOVB ZR, (R0) -end: - RET - -unaligned_to_16: - MOVD R0, R2 -head_loop: - MOVBU.P ZR, 1(R0) - ANDS $15, R0, ZR - BNE head_loop - // Adjust length for what remains - SUB R2, R0, R3 - SUB R3, R1 - // If size is less than 16 bytes, use tail_zero to zero what remains - CMP $16, R1 - BLT tail_zero - B aligned_to_16 diff --git a/pkg/sentry/platform/safecopy/memcpy_amd64.s b/pkg/sentry/platform/safecopy/memcpy_amd64.s deleted file mode 100644 index 129691d68..000000000 --- a/pkg/sentry/platform/safecopy/memcpy_amd64.s +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright © 1994-1999 Lucent Technologies Inc. All rights reserved. -// Revisions Copyright © 2000-2007 Vita Nuova Holdings Limited (www.vitanuova.com). All rights reserved. -// Portions Copyright 2009 The Go Authors. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#include "textflag.h" - -// handleMemcpyFault returns (the value stored in AX, the value stored in DI). -// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, -// with the faulting address stored in AX and the signal number stored in DI. -// -// It must have the same frame configuration as memcpy so that it can undo any -// potential call frame set up by the assembler. -TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 - MOVQ AX, addr+24(FP) - MOVL DI, sig+32(FP) - RET - -// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received -// during the copy, it returns the address that caused the fault and the number -// of the signal that was received. Otherwise, it returns an unspecified address -// and a signal number of 0. -// -// Data is copied in order, such that if a fault happens at address p, it is -// safe to assume that all data before p-maxRegisterSize has already been -// successfully copied. -// -// The code is derived from the forward copying part of runtime.memmove. -// -// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) -TEXT ·memcpy(SB), NOSPLIT, $0-36 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleMemcpyFault will store a different value in this address. - MOVL $0, sig+32(FP) - - MOVQ to+0(FP), DI - MOVQ from+8(FP), SI - MOVQ n+16(FP), BX - - // REP instructions have a high startup cost, so we handle small sizes - // with some straightline code. The REP MOVSQ instruction is really fast - // for large sizes. The cutover is approximately 2K. -tail: - // move_129through256 or smaller work whether or not the source and the - // destination memory regions overlap because they load all data into - // registers before writing it back. move_256through2048 on the other - // hand can be used only when the memory regions don't overlap or the copy - // direction is forward. - TESTQ BX, BX - JEQ move_0 - CMPQ BX, $2 - JBE move_1or2 - CMPQ BX, $4 - JBE move_3or4 - CMPQ BX, $8 - JB move_5through7 - JE move_8 - CMPQ BX, $16 - JBE move_9through16 - CMPQ BX, $32 - JBE move_17through32 - CMPQ BX, $64 - JBE move_33through64 - CMPQ BX, $128 - JBE move_65through128 - CMPQ BX, $256 - JBE move_129through256 - // TODO: use branch table and BSR to make this just a single dispatch - -/* - * forward copy loop - */ - CMPQ BX, $2048 - JLS move_256through2048 - - // Check alignment - MOVL SI, AX - ORL DI, AX - TESTL $7, AX - JEQ fwdBy8 - - // Do 1 byte at a time - MOVQ BX, CX - REP; MOVSB - RET - -fwdBy8: - // Do 8 bytes at a time - MOVQ BX, CX - SHRQ $3, CX - ANDQ $7, BX - REP; MOVSQ - JMP tail - -move_1or2: - MOVB (SI), AX - MOVB AX, (DI) - MOVB -1(SI)(BX*1), CX - MOVB CX, -1(DI)(BX*1) - RET -move_0: - RET -move_3or4: - MOVW (SI), AX - MOVW AX, (DI) - MOVW -2(SI)(BX*1), CX - MOVW CX, -2(DI)(BX*1) - RET -move_5through7: - MOVL (SI), AX - MOVL AX, (DI) - MOVL -4(SI)(BX*1), CX - MOVL CX, -4(DI)(BX*1) - RET -move_8: - // We need a separate case for 8 to make sure we write pointers atomically. - MOVQ (SI), AX - MOVQ AX, (DI) - RET -move_9through16: - MOVQ (SI), AX - MOVQ AX, (DI) - MOVQ -8(SI)(BX*1), CX - MOVQ CX, -8(DI)(BX*1) - RET -move_17through32: - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU -16(SI)(BX*1), X1 - MOVOU X1, -16(DI)(BX*1) - RET -move_33through64: - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU -32(SI)(BX*1), X2 - MOVOU X2, -32(DI)(BX*1) - MOVOU -16(SI)(BX*1), X3 - MOVOU X3, -16(DI)(BX*1) - RET -move_65through128: - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU 32(SI), X2 - MOVOU X2, 32(DI) - MOVOU 48(SI), X3 - MOVOU X3, 48(DI) - MOVOU -64(SI)(BX*1), X4 - MOVOU X4, -64(DI)(BX*1) - MOVOU -48(SI)(BX*1), X5 - MOVOU X5, -48(DI)(BX*1) - MOVOU -32(SI)(BX*1), X6 - MOVOU X6, -32(DI)(BX*1) - MOVOU -16(SI)(BX*1), X7 - MOVOU X7, -16(DI)(BX*1) - RET -move_129through256: - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU 32(SI), X2 - MOVOU X2, 32(DI) - MOVOU 48(SI), X3 - MOVOU X3, 48(DI) - MOVOU 64(SI), X4 - MOVOU X4, 64(DI) - MOVOU 80(SI), X5 - MOVOU X5, 80(DI) - MOVOU 96(SI), X6 - MOVOU X6, 96(DI) - MOVOU 112(SI), X7 - MOVOU X7, 112(DI) - MOVOU -128(SI)(BX*1), X8 - MOVOU X8, -128(DI)(BX*1) - MOVOU -112(SI)(BX*1), X9 - MOVOU X9, -112(DI)(BX*1) - MOVOU -96(SI)(BX*1), X10 - MOVOU X10, -96(DI)(BX*1) - MOVOU -80(SI)(BX*1), X11 - MOVOU X11, -80(DI)(BX*1) - MOVOU -64(SI)(BX*1), X12 - MOVOU X12, -64(DI)(BX*1) - MOVOU -48(SI)(BX*1), X13 - MOVOU X13, -48(DI)(BX*1) - MOVOU -32(SI)(BX*1), X14 - MOVOU X14, -32(DI)(BX*1) - MOVOU -16(SI)(BX*1), X15 - MOVOU X15, -16(DI)(BX*1) - RET -move_256through2048: - SUBQ $256, BX - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU 32(SI), X2 - MOVOU X2, 32(DI) - MOVOU 48(SI), X3 - MOVOU X3, 48(DI) - MOVOU 64(SI), X4 - MOVOU X4, 64(DI) - MOVOU 80(SI), X5 - MOVOU X5, 80(DI) - MOVOU 96(SI), X6 - MOVOU X6, 96(DI) - MOVOU 112(SI), X7 - MOVOU X7, 112(DI) - MOVOU 128(SI), X8 - MOVOU X8, 128(DI) - MOVOU 144(SI), X9 - MOVOU X9, 144(DI) - MOVOU 160(SI), X10 - MOVOU X10, 160(DI) - MOVOU 176(SI), X11 - MOVOU X11, 176(DI) - MOVOU 192(SI), X12 - MOVOU X12, 192(DI) - MOVOU 208(SI), X13 - MOVOU X13, 208(DI) - MOVOU 224(SI), X14 - MOVOU X14, 224(DI) - MOVOU 240(SI), X15 - MOVOU X15, 240(DI) - CMPQ BX, $256 - LEAQ 256(SI), SI - LEAQ 256(DI), DI - JGE move_256through2048 - JMP tail diff --git a/pkg/sentry/platform/safecopy/memcpy_arm64.s b/pkg/sentry/platform/safecopy/memcpy_arm64.s deleted file mode 100644 index e7e541565..000000000 --- a/pkg/sentry/platform/safecopy/memcpy_arm64.s +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#include "textflag.h" - -// handleMemcpyFault returns (the value stored in R0, the value stored in R1). -// Control is transferred to it when memcpy below receives SIGSEGV or SIGBUS, -// with the faulting address stored in R0 and the signal number stored in R1. -// -// It must have the same frame configuration as memcpy so that it can undo any -// potential call frame set up by the assembler. -TEXT handleMemcpyFault(SB), NOSPLIT, $0-36 - MOVD R0, addr+24(FP) - MOVW R1, sig+32(FP) - RET - -// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received -// during the copy, it returns the address that caused the fault and the number -// of the signal that was received. Otherwise, it returns an unspecified address -// and a signal number of 0. -// -// Data is copied in order, such that if a fault happens at address p, it is -// safe to assume that all data before p-maxRegisterSize has already been -// successfully copied. -// -// The code is derived from the Go source runtime.memmove. -// -// func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) -TEXT ·memcpy(SB), NOSPLIT, $-8-36 - // Store 0 as the returned signal number. If we run to completion, - // this is the value the caller will see; if a signal is received, - // handleMemcpyFault will store a different value in this address. - MOVW $0, sig+32(FP) - - MOVD to+0(FP), R3 - MOVD from+8(FP), R4 - MOVD n+16(FP), R5 - CMP $0, R5 - BNE check - RET - -check: - AND $~7, R5, R7 // R7 is N&~7. - SUB R7, R5, R6 // R6 is N&7. - - // Copying forward proceeds by copying R7/8 words then copying R6 bytes. - // R3 and R4 are advanced as we copy. - - // (There may be implementations of armv8 where copying by bytes until - // at least one of source or dest is word aligned is a worthwhile - // optimization, but the on the one tested so far (xgene) it did not - // make a significance difference.) - - CMP $0, R7 // Do we need to do any word-by-word copying? - BEQ noforwardlarge - ADD R3, R7, R9 // R9 points just past where we copy by word. - -forwardlargeloop: - MOVD.P 8(R4), R8 // R8 is just a scratch register. - MOVD.P R8, 8(R3) - CMP R3, R9 - BNE forwardlargeloop - -noforwardlarge: - CMP $0, R6 // Do we need to do any byte-by-byte copying? - BNE forwardtail - RET - -forwardtail: - ADD R3, R6, R9 // R9 points just past the destination memory. - -forwardtailloop: - MOVBU.P 1(R4), R8 - MOVBU.P R8, 1(R3) - CMP R3, R9 - BNE forwardtailloop - RET diff --git a/pkg/sentry/platform/safecopy/safecopy.go b/pkg/sentry/platform/safecopy/safecopy.go deleted file mode 100644 index 2fb7e5809..000000000 --- a/pkg/sentry/platform/safecopy/safecopy.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package safecopy provides an efficient implementation of functions to access -// memory that may result in SIGSEGV or SIGBUS being sent to the accessor. -package safecopy - -import ( - "fmt" - "reflect" - "runtime" - "syscall" - - "gvisor.dev/gvisor/pkg/syserror" -) - -// SegvError is returned when a safecopy function receives SIGSEGV. -type SegvError struct { - // Addr is the address at which the SIGSEGV occurred. - Addr uintptr -} - -// Error implements error.Error. -func (e SegvError) Error() string { - return fmt.Sprintf("SIGSEGV at %#x", e.Addr) -} - -// BusError is returned when a safecopy function receives SIGBUS. -type BusError struct { - // Addr is the address at which the SIGBUS occurred. - Addr uintptr -} - -// Error implements error.Error. -func (e BusError) Error() string { - return fmt.Sprintf("SIGBUS at %#x", e.Addr) -} - -// AlignmentError is returned when a safecopy function is passed an address -// that does not meet alignment requirements. -type AlignmentError struct { - // Addr is the invalid address. - Addr uintptr - - // Alignment is the required alignment. - Alignment uintptr -} - -// Error implements error.Error. -func (e AlignmentError) Error() string { - return fmt.Sprintf("address %#x is not aligned to a %d-byte boundary", e.Addr, e.Alignment) -} - -var ( - // The begin and end addresses below are for the functions that are - // checked by the signal handler. - memcpyBegin uintptr - memcpyEnd uintptr - memclrBegin uintptr - memclrEnd uintptr - swapUint32Begin uintptr - swapUint32End uintptr - swapUint64Begin uintptr - swapUint64End uintptr - compareAndSwapUint32Begin uintptr - compareAndSwapUint32End uintptr - loadUint32Begin uintptr - loadUint32End uintptr - - // savedSigSegVHandler is a pointer to the SIGSEGV handler that was - // configured before we replaced it with our own. We still call into it - // when we get a SIGSEGV that is not interesting to us. - savedSigSegVHandler uintptr - - // same a above, but for SIGBUS signals. - savedSigBusHandler uintptr -) - -// signalHandler is our replacement signal handler for SIGSEGV and SIGBUS -// signals. -func signalHandler() - -// FindEndAddress returns the end address (one byte beyond the last) of the -// function that contains the specified address (begin). -func FindEndAddress(begin uintptr) uintptr { - f := runtime.FuncForPC(begin) - if f != nil { - for p := begin; ; p++ { - g := runtime.FuncForPC(p) - if f != g { - return p - } - } - } - return begin -} - -// initializeAddresses initializes the addresses used by the signal handler. -func initializeAddresses() { - // The following functions are written in assembly language, so they won't - // be inlined by the existing compiler/linker. Tests will fail if this - // assumption is violated. - memcpyBegin = reflect.ValueOf(memcpy).Pointer() - memcpyEnd = FindEndAddress(memcpyBegin) - memclrBegin = reflect.ValueOf(memclr).Pointer() - memclrEnd = FindEndAddress(memclrBegin) - swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() - swapUint32End = FindEndAddress(swapUint32Begin) - swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() - swapUint64End = FindEndAddress(swapUint64Begin) - compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() - compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) - loadUint32Begin = reflect.ValueOf(loadUint32).Pointer() - loadUint32End = FindEndAddress(loadUint32Begin) -} - -func init() { - initializeAddresses() - if err := ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { - panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) - } - if err := ReplaceSignalHandler(syscall.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { - panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) - } - syserror.AddErrorUnwrapper(func(e error) (syscall.Errno, bool) { - switch e.(type) { - case SegvError, BusError, AlignmentError: - return syscall.EFAULT, true - default: - return 0, false - } - }) -} diff --git a/pkg/sentry/platform/safecopy/safecopy_test.go b/pkg/sentry/platform/safecopy/safecopy_test.go deleted file mode 100644 index 5818f7f9b..000000000 --- a/pkg/sentry/platform/safecopy/safecopy_test.go +++ /dev/null @@ -1,617 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safecopy - -import ( - "bytes" - "fmt" - "io/ioutil" - "math/rand" - "os" - "runtime/debug" - "syscall" - "testing" - "unsafe" -) - -// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular -// dependency. -const pageSize = 4096 - -func initRandom(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func randBuf(size int) []byte { - b := make([]byte, size) - initRandom(b) - return b -} - -func TestCopyInSuccess(t *testing.T) { - // Test that CopyIn does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyIn(b, unsafe.Pointer(&a[0])) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopyOutSuccess(t *testing.T) { - // Test that CopyOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyOut(unsafe.Pointer(&b[0]), a) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopySuccess(t *testing.T) { - // Test that Copy does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestZeroOutSuccess(t *testing.T) { - // Test that ZeroOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := make([]byte, bufLen) - b := randBuf(bufLen) - - n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestSwapUint32Success(t *testing.T) { - // Test that SwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := SwapUint32(unsafe.Pointer(&val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestSwapUint32AlignmentError(t *testing.T) { - // Test that SwapUint32 returns an AlignmentError when passed an unaligned - // address. - data := new(struct{ val uint64 }) - addr := uintptr(unsafe.Pointer(&data.val)) + 1 - want := AlignmentError{Addr: addr, Alignment: 4} - if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestSwapUint64Success(t *testing.T) { - // Test that SwapUint64 does not return an error when the page is - // accessible. - before := uint64(rand.Int63()) - after := uint64(rand.Int63()) - // "The first word in ... an allocated struct or slice can be relied upon - // to be 64-bit aligned." - sync/atomic docs - data := new(struct{ val uint64 }) - data.val = before - - old, err := SwapUint64(unsafe.Pointer(&data.val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if data.val != after { - t.Errorf("Unexpected new value: got %v, want %v", data.val, after) - } -} - -func TestSwapUint64AlignmentError(t *testing.T) { - // Test that SwapUint64 returns an AlignmentError when passed an unaligned - // address. - data := new(struct{ val1, val2 uint64 }) - addr := uintptr(unsafe.Pointer(&data.val1)) + 1 - want := AlignmentError{Addr: addr, Alignment: 8} - if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestCompareAndSwapUint32Success(t *testing.T) { - // Test that CompareAndSwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestCompareAndSwapUint32AlignmentError(t *testing.T) { - // Test that CompareAndSwapUint32 returns an AlignmentError when passed an - // unaligned address. - data := new(struct{ val uint64 }) - addr := uintptr(unsafe.Pointer(&data.val)) + 1 - want := AlignmentError{Addr: addr, Alignment: 4} - if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -// withSegvErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGSEGV when accessed. -func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { - mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { - t.Fatalf("Mprotect failed: %v", err) - } - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -// withBusErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGBUS when accessed. -func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - if err := f.Truncate(pageSize); err != nil { - t.Fatalf("Truncate failed: %v", err) - } - mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -func TestCopyInSegvError(t *testing.T) { - // Test that CopyIn returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyInBusError(t *testing.T) { - // Test that CopyIn returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutSegvError(t *testing.T) { - // Test that CopyOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutBusError(t *testing.T) { - // Test that CopyOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying from a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceBusError(t *testing.T) { - // Test that Copy returns a BusError when copying from a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying to a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationBusError(t *testing.T) { - // Test that Copy returns a BusError when copying to a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestZeroOutSegvError(t *testing.T) { - // Test that ZeroOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestZeroOutBusError(t *testing.T) { - // Test that ZeroOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestSwapUint32SegvError(t *testing.T) { - // Test that SwapUint32 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint32BusError(t *testing.T) { - // Test that SwapUint32 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64SegvError(t *testing.T) { - // Test that SwapUint64 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64BusError(t *testing.T) { - // Test that SwapUint64 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32SegvError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a SegvError when reaching a page - // that signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32BusError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a BusError when reaching a page - // that signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func testCopy(dst, src []byte) (panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/sentry/platform/safecopy/safecopy_unsafe.go b/pkg/sentry/platform/safecopy/safecopy_unsafe.go deleted file mode 100644 index eef028e68..000000000 --- a/pkg/sentry/platform/safecopy/safecopy_unsafe.go +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safecopy - -import ( - "fmt" - "syscall" - "unsafe" -) - -// maxRegisterSize is the maximum register size used in memcpy and memclr. It -// is used to decide by how much to rewind the copy (for memcpy) or zeroing -// (for memclr) before proceeding. -const maxRegisterSize = 16 - -// memcpy copies data from src to dst. If a SIGSEGV or SIGBUS signal is received -// during the copy, it returns the address that caused the fault and the number -// of the signal that was received. Otherwise, it returns an unspecified address -// and a signal number of 0. -// -// Data is copied in order, such that if a fault happens at address p, it is -// safe to assume that all data before p-maxRegisterSize has already been -// successfully copied. -// -//go:noescape -func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) - -// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS -// signal is received during the write, it returns the address that caused the -// fault and the number of the signal that was received. Otherwise, it returns -// an unspecified address and a signal number of 0. -// -// Data is written in order, such that if a fault happens at address p, it is -// safe to assume that all data before p-maxRegisterSize has already been -// successfully written. -// -//go:noescape -func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32) - -// swapUint32 atomically stores new into *ptr and returns (the previous *ptr -// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the -// value of old is unspecified, and sig is the number of the signal that was -// received. -// -// Preconditions: ptr must be aligned to a 4-byte boundary. -// -//go:noescape -func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) - -// swapUint64 atomically stores new into *ptr and returns (the previous *ptr -// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the -// value of old is unspecified, and sig is the number of the signal that was -// received. -// -// Preconditions: ptr must be aligned to a 8-byte boundary. -// -//go:noescape -func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) - -// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is -// received during the operation, the value of prev is unspecified, and sig is -// the number of the signal that was received. -// -// Preconditions: ptr must be aligned to a 4-byte boundary. -// -//go:noescape -func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) - -// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It -// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr. -// -// Preconditions: ptr must be aligned to a 4-byte boundary. -// -//go:noescape -func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) - -// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes -// copied and an error if SIGSEGV or SIGBUS is received while reading from src. -func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { - toCopy := uintptr(len(dst)) - if len(dst) == 0 { - return 0, nil - } - - fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy) - if sig == 0 { - return len(dst), nil - } - - faultN, srcN := uintptr(fault), uintptr(src) - if faultN < srcN || faultN >= srcN+toCopy { - panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy)) - } - - // memcpy might have ended the copy up to maxRegisterSize bytes before - // fault, if an instruction caused a memory access that straddled two - // pages, and the second one faulted. Try to copy up to the fault. - var done int - if faultN-srcN > maxRegisterSize { - done = int(faultN - srcN - maxRegisterSize) - } - n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done))) - done += n - if err != nil { - return done, err - } - return done, errorFromFaultSignal(fault, sig) -} - -// CopyOut copies len(src) bytes from src to dst. If returns the number of -// bytes done and an error if SIGSEGV or SIGBUS is received while writing to -// dst. -func CopyOut(dst unsafe.Pointer, src []byte) (int, error) { - toCopy := uintptr(len(src)) - if toCopy == 0 { - return 0, nil - } - - fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy) - if sig == 0 { - return len(src), nil - } - - faultN, dstN := uintptr(fault), uintptr(dst) - if faultN < dstN || faultN >= dstN+toCopy { - panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy)) - } - - // memcpy might have ended the copy up to maxRegisterSize bytes before - // fault, if an instruction caused a memory access that straddled two - // pages, and the second one faulted. Try to copy up to the fault. - var done int - if faultN-dstN > maxRegisterSize { - done = int(faultN - dstN - maxRegisterSize) - } - n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)]) - done += n - if err != nil { - return done, err - } - return done, errorFromFaultSignal(fault, sig) -} - -// Copy copies toCopy bytes from src to dst. It returns the number of bytes -// copied and an error if SIGSEGV or SIGBUS is received while reading from src -// or writing to dst. -// -// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap, -// the resulting contents of dst are unspecified. -func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) { - if toCopy == 0 { - return 0, nil - } - - fault, sig := memcpy(dst, src, toCopy) - if sig == 0 { - return toCopy, nil - } - - // Did the fault occur while reading from src or writing to dst? - faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst) - faultAfterSrc := ^uintptr(0) - if faultN >= srcN { - faultAfterSrc = faultN - srcN - } - faultAfterDst := ^uintptr(0) - if faultN >= dstN { - faultAfterDst = faultN - dstN - } - if faultAfterSrc >= toCopy && faultAfterDst >= toCopy { - panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy)) - } - faultedAfter := faultAfterSrc - if faultedAfter > faultAfterDst { - faultedAfter = faultAfterDst - } - - // memcpy might have ended the copy up to maxRegisterSize bytes before - // fault, if an instruction caused a memory access that straddled two - // pages, and the second one faulted. Try to copy up to the fault. - var done uintptr - if faultedAfter > maxRegisterSize { - done = faultedAfter - maxRegisterSize - } - n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done) - done += n - if err != nil { - return done, err - } - return done, errorFromFaultSignal(fault, sig) -} - -// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes -// written and an error if SIGSEGV or SIGBUS is received while writing to dst. -func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) { - if toZero == 0 { - return 0, nil - } - - fault, sig := memclr(dst, toZero) - if sig == 0 { - return toZero, nil - } - - faultN, dstN := uintptr(fault), uintptr(dst) - if faultN < dstN || faultN >= dstN+toZero { - panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero)) - } - - // memclr might have ended the write up to maxRegisterSize bytes before - // fault, if an instruction caused a memory access that straddled two - // pages, and the second one faulted. Try to write up to the fault. - var done uintptr - if faultN-dstN > maxRegisterSize { - done = faultN - dstN - maxRegisterSize - } - n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done) - done += n - if err != nil { - return done, err - } - return done, errorFromFaultSignal(fault, sig) -} - -// SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns -// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is -// not aligned to a 4-byte boundary. -func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) { - if addr := uintptr(ptr); addr&3 != 0 { - return 0, AlignmentError{addr, 4} - } - old, sig := swapUint32(ptr, new) - return old, errorFromFaultSignal(ptr, sig) -} - -// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns -// an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is -// not aligned to an 8-byte boundary. -func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) { - if addr := uintptr(ptr); addr&7 != 0 { - return 0, AlignmentError{addr, 8} - } - old, sig := swapUint64(ptr, new) - return old, errorFromFaultSignal(ptr, sig) -} - -// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32, -// except that it returns an error if SIGSEGV or SIGBUS is received while -// accessing ptr, or if ptr is not aligned to a 4-byte boundary. -func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) { - if addr := uintptr(ptr); addr&3 != 0 { - return 0, AlignmentError{addr, 4} - } - prev, sig := compareAndSwapUint32(ptr, old, new) - return prev, errorFromFaultSignal(ptr, sig) -} - -// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It -// may fail with SIGSEGV or SIGBUS if it is received while reading from ptr. -// -// Preconditions: ptr must be aligned to a 4-byte boundary. -func LoadUint32(ptr unsafe.Pointer) (uint32, error) { - if addr := uintptr(ptr); addr&3 != 0 { - return 0, AlignmentError{addr, 4} - } - val, sig := loadUint32(ptr) - return val, errorFromFaultSignal(ptr, sig) -} - -func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error { - switch sig { - case 0: - return nil - case int32(syscall.SIGSEGV): - return SegvError{uintptr(addr)} - case int32(syscall.SIGBUS): - return BusError{uintptr(addr)} - default: - panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr)) - } -} - -// ReplaceSignalHandler replaces the existing signal handler for the provided -// signal with the one that handles faults in safecopy-protected functions. -// -// It stores the value of the previously set handler in previous. -// -// This function will be called on initialization in order to install safecopy -// handlers for appropriate signals. These handlers will call the previous -// handler however, and if this is function is being used externally then the -// same courtesy is expected. -func ReplaceSignalHandler(sig syscall.Signal, handler uintptr, previous *uintptr) error { - var sa struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 - } - const maskLen = 8 - - // Get the existing signal handler information, and save the current - // handler. Once we replace it, we will use this pointer to fall back to - // it when we receive other signals. - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 { - return e - } - - // Fail if there isn't a previous handler. - if sa.handler == 0 { - return fmt.Errorf("previous handler for signal %x isn't set", sig) - } - - *previous = sa.handler - - // Install our own handler. - sa.handler = handler - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { - return e - } - - return nil -} diff --git a/pkg/sentry/platform/safecopy/sighandler_amd64.s b/pkg/sentry/platform/safecopy/sighandler_amd64.s deleted file mode 100644 index 475ae48e9..000000000 --- a/pkg/sentry/platform/safecopy/sighandler_amd64.s +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// The signals handled by sigHandler. -#define SIGBUS 7 -#define SIGSEGV 11 - -// Offsets to the registers in context->uc_mcontext.gregs[]. -#define REG_RDI 0x68 -#define REG_RAX 0x90 -#define REG_IP 0xa8 - -// Offset to the si_addr field of siginfo. -#define SI_CODE 0x08 -#define SI_ADDR 0x10 - -// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must -// not be set up as a handler to any other signals. -// -// If the instruction causing the signal is within a safecopy-protected -// function, the signal is handled such that execution resumes in the -// appropriate fault handling stub with AX containing the faulting address and -// DI containing the signal number. Otherwise control is transferred to the -// previously configured signal handler (savedSigSegvHandler or -// savedSigBusHandler). -// -// This function cannot be written in go because it runs whenever a signal is -// received by the thread (preempting whatever was running), which includes when -// garbage collector has stopped or isn't expecting any interactions (like -// barriers). -// -// The arguments are the following: -// DI - The signal number. -// SI - Pointer to siginfo_t structure. -// DX - Pointer to ucontext structure. -TEXT ·signalHandler(SB),NOSPLIT,$0 - // Check if the signal is from the kernel. - MOVQ $0x0, CX - CMPL CX, SI_CODE(SI) - JGE original_handler - - // Check if RIP is within the area we care about. - MOVQ REG_IP(DX), CX - CMPQ CX, ·memcpyBegin(SB) - JB not_memcpy - CMPQ CX, ·memcpyEnd(SB) - JAE not_memcpy - - // Modify the context such that execution will resume in the fault - // handler. - LEAQ handleMemcpyFault(SB), CX - JMP handle_fault - -not_memcpy: - CMPQ CX, ·memclrBegin(SB) - JB not_memclr - CMPQ CX, ·memclrEnd(SB) - JAE not_memclr - - LEAQ handleMemclrFault(SB), CX - JMP handle_fault - -not_memclr: - CMPQ CX, ·swapUint32Begin(SB) - JB not_swapuint32 - CMPQ CX, ·swapUint32End(SB) - JAE not_swapuint32 - - LEAQ handleSwapUint32Fault(SB), CX - JMP handle_fault - -not_swapuint32: - CMPQ CX, ·swapUint64Begin(SB) - JB not_swapuint64 - CMPQ CX, ·swapUint64End(SB) - JAE not_swapuint64 - - LEAQ handleSwapUint64Fault(SB), CX - JMP handle_fault - -not_swapuint64: - CMPQ CX, ·compareAndSwapUint32Begin(SB) - JB not_casuint32 - CMPQ CX, ·compareAndSwapUint32End(SB) - JAE not_casuint32 - - LEAQ handleCompareAndSwapUint32Fault(SB), CX - JMP handle_fault - -not_casuint32: - CMPQ CX, ·loadUint32Begin(SB) - JB not_loaduint32 - CMPQ CX, ·loadUint32End(SB) - JAE not_loaduint32 - - LEAQ handleLoadUint32Fault(SB), CX - JMP handle_fault - -not_loaduint32: -original_handler: - // Jump to the previous signal handler, which is likely the golang one. - XORQ CX, CX - MOVQ ·savedSigBusHandler(SB), AX - CMPL DI, $SIGSEGV - CMOVQEQ ·savedSigSegVHandler(SB), AX - JMP AX - -handle_fault: - // Entered with the address of the fault handler in RCX; store it in - // RIP. - MOVQ CX, REG_IP(DX) - - // Store the faulting address in RAX. - MOVQ SI_ADDR(SI), CX - MOVQ CX, REG_RAX(DX) - - // Store the signal number in EDI. - MOVL DI, REG_RDI(DX) - - RET diff --git a/pkg/sentry/platform/safecopy/sighandler_arm64.s b/pkg/sentry/platform/safecopy/sighandler_arm64.s deleted file mode 100644 index 53e4ac2c1..000000000 --- a/pkg/sentry/platform/safecopy/sighandler_arm64.s +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -// The signals handled by sigHandler. -#define SIGBUS 7 -#define SIGSEGV 11 - -// Offsets to the registers in context->uc_mcontext.gregs[]. -#define REG_R0 0xB8 -#define REG_R1 0xC0 -#define REG_PC 0x1B8 - -// Offset to the si_addr field of siginfo. -#define SI_CODE 0x08 -#define SI_ADDR 0x10 - -// signalHandler is the signal handler for SIGSEGV and SIGBUS signals. It must -// not be set up as a handler to any other signals. -// -// If the instruction causing the signal is within a safecopy-protected -// function, the signal is handled such that execution resumes in the -// appropriate fault handling stub with R0 containing the faulting address and -// R1 containing the signal number. Otherwise control is transferred to the -// previously configured signal handler (savedSigSegvHandler or -// savedSigBusHandler). -// -// This function cannot be written in go because it runs whenever a signal is -// received by the thread (preempting whatever was running), which includes when -// garbage collector has stopped or isn't expecting any interactions (like -// barriers). -// -// The arguments are the following: -// R0 - The signal number. -// R1 - Pointer to siginfo_t structure. -// R2 - Pointer to ucontext structure. -TEXT ·signalHandler(SB),NOSPLIT,$0 - // Check if the signal is from the kernel, si_code > 0 means a kernel signal. - MOVD SI_CODE(R1), R7 - CMPW $0x0, R7 - BLE original_handler - - // Check if PC is within the area we care about. - MOVD REG_PC(R2), R7 - MOVD ·memcpyBegin(SB), R8 - CMP R8, R7 - BLO not_memcpy - MOVD ·memcpyEnd(SB), R8 - CMP R8, R7 - BHS not_memcpy - - // Modify the context such that execution will resume in the fault handler. - MOVD $handleMemcpyFault(SB), R7 - B handle_fault - -not_memcpy: - MOVD ·memclrBegin(SB), R8 - CMP R8, R7 - BLO not_memclr - MOVD ·memclrEnd(SB), R8 - CMP R8, R7 - BHS not_memclr - - MOVD $handleMemclrFault(SB), R7 - B handle_fault - -not_memclr: - MOVD ·swapUint32Begin(SB), R8 - CMP R8, R7 - BLO not_swapuint32 - MOVD ·swapUint32End(SB), R8 - CMP R8, R7 - BHS not_swapuint32 - - MOVD $handleSwapUint32Fault(SB), R7 - B handle_fault - -not_swapuint32: - MOVD ·swapUint64Begin(SB), R8 - CMP R8, R7 - BLO not_swapuint64 - MOVD ·swapUint64End(SB), R8 - CMP R8, R7 - BHS not_swapuint64 - - MOVD $handleSwapUint64Fault(SB), R7 - B handle_fault - -not_swapuint64: - MOVD ·compareAndSwapUint32Begin(SB), R8 - CMP R8, R7 - BLO not_casuint32 - MOVD ·compareAndSwapUint32End(SB), R8 - CMP R8, R7 - BHS not_casuint32 - - MOVD $handleCompareAndSwapUint32Fault(SB), R7 - B handle_fault - -not_casuint32: - MOVD ·loadUint32Begin(SB), R8 - CMP R8, R7 - BLO not_loaduint32 - MOVD ·loadUint32End(SB), R8 - CMP R8, R7 - BHS not_loaduint32 - - MOVD $handleLoadUint32Fault(SB), R7 - B handle_fault - -not_loaduint32: -original_handler: - // Jump to the previous signal handler, which is likely the golang one. - MOVD ·savedSigBusHandler(SB), R7 - MOVD ·savedSigSegVHandler(SB), R8 - CMPW $SIGSEGV, R0 - CSEL EQ, R8, R7, R7 - B (R7) - -handle_fault: - // Entered with the address of the fault handler in R7; store it in PC. - MOVD R7, REG_PC(R2) - - // Store the faulting address in R0. - MOVD SI_ADDR(R1), R7 - MOVD R7, REG_R0(R2) - - // Store the signal number in R1. - MOVW R0, REG_R1(R2) - - RET diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD deleted file mode 100644 index 3ab76da97..000000000 --- a/pkg/sentry/safemem/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safemem", - srcs = [ - "block_unsafe.go", - "io.go", - "safemem.go", - "seq_unsafe.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/platform/safecopy", - ], -) - -go_test( - name = "safemem_test", - size = "small", - srcs = [ - "io_test.go", - "seq_test.go", - ], - library = ":safemem", -) diff --git a/pkg/sentry/safemem/block_unsafe.go b/pkg/sentry/safemem/block_unsafe.go deleted file mode 100644 index 6f03c94bf..000000000 --- a/pkg/sentry/safemem/block_unsafe.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "fmt" - "reflect" - "unsafe" - - "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" -) - -// A Block is a range of contiguous bytes, similar to []byte but with the -// following differences: -// -// - The memory represented by a Block may require the use of safecopy to -// access. -// -// - Block does not carry a capacity and cannot be expanded. -// -// Blocks are immutable and may be copied by value. The zero value of Block -// represents an empty range, analogous to a nil []byte. -type Block struct { - // [start, start+length) is the represented memory. - // - // start is an unsafe.Pointer to ensure that Block prevents the represented - // memory from being garbage-collected. - start unsafe.Pointer - length int - - // needSafecopy is true if accessing the represented memory requires the - // use of safecopy. - needSafecopy bool -} - -// BlockFromSafeSlice returns a Block equivalent to slice, which is safe to -// access without safecopy. -func BlockFromSafeSlice(slice []byte) Block { - return blockFromSlice(slice, false) -} - -// BlockFromUnsafeSlice returns a Block equivalent to bs, which is not safe to -// access without safecopy. -func BlockFromUnsafeSlice(slice []byte) Block { - return blockFromSlice(slice, true) -} - -func blockFromSlice(slice []byte, needSafecopy bool) Block { - if len(slice) == 0 { - return Block{} - } - return Block{ - start: unsafe.Pointer(&slice[0]), - length: len(slice), - needSafecopy: needSafecopy, - } -} - -// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is -// safe to access without safecopy. -// -// Preconditions: ptr+len does not overflow. -func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, false) -} - -// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which -// is not safe to access without safecopy. -// -// Preconditions: ptr+len does not overflow. -func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, true) -} - -func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { - if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { - panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) - } - return Block{ - start: ptr, - length: len, - needSafecopy: needSafecopy, - } -} - -// DropFirst returns a Block equivalent to b, but with the first n bytes -// omitted. It is analogous to the [n:] operation on a slice, except that if n -// > b.Len(), DropFirst returns an empty Block instead of panicking. -// -// Preconditions: n >= 0. -func (b Block) DropFirst(n int) Block { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return b.DropFirst64(uint64(n)) -} - -// DropFirst64 is equivalent to DropFirst but takes a uint64. -func (b Block) DropFirst64(n uint64) Block { - if n >= uint64(b.length) { - return Block{} - } - return Block{ - start: unsafe.Pointer(uintptr(b.start) + uintptr(n)), - length: b.length - int(n), - needSafecopy: b.needSafecopy, - } -} - -// TakeFirst returns a Block equivalent to the first n bytes of b. It is -// analogous to the [:n] operation on a slice, except that if n > b.Len(), -// TakeFirst returns a copy of b instead of panicking. -// -// Preconditions: n >= 0. -func (b Block) TakeFirst(n int) Block { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return b.TakeFirst64(uint64(n)) -} - -// TakeFirst64 is equivalent to TakeFirst but takes a uint64. -func (b Block) TakeFirst64(n uint64) Block { - if n == 0 { - return Block{} - } - if n >= uint64(b.length) { - return b - } - return Block{ - start: b.start, - length: int(n), - needSafecopy: b.needSafecopy, - } -} - -// ToSlice returns a []byte equivalent to b. -func (b Block) ToSlice() []byte { - var bs []byte - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) - hdr.Data = uintptr(b.start) - hdr.Len = b.length - hdr.Cap = b.length - return bs -} - -// Addr returns b's start address as a uintptr. It returns uintptr instead of -// unsafe.Pointer so that code using safemem cannot obtain unsafe.Pointers -// without importing the unsafe package explicitly. -// -// Note that a uintptr is not recognized as a pointer by the garbage collector, -// such that if there are no uses of b after a call to b.Addr() and the address -// is to Go-managed memory, the returned uintptr does not prevent garbage -// collection of the pointee. -func (b Block) Addr() uintptr { - return uintptr(b.start) -} - -// Len returns b's length in bytes. -func (b Block) Len() int { - return b.length -} - -// NeedSafecopy returns true if accessing b.ToSlice() requires the use of safecopy. -func (b Block) NeedSafecopy() bool { - return b.needSafecopy -} - -// String implements fmt.Stringer.String. -func (b Block) String() string { - if uintptr(b.start) == 0 && b.length == 0 { - return "" - } - var suffix string - if b.needSafecopy { - suffix = "*" - } - return fmt.Sprintf("[%#x-%#x)%s", uintptr(b.start), uintptr(b.start)+uintptr(b.length), suffix) -} - -// Copy copies src.Len() or dst.Len() bytes, whichever is less, from src -// to dst and returns the number of bytes copied. -// -// If src and dst overlap, the data stored in dst is unspecified. -func Copy(dst, src Block) (int, error) { - if !dst.needSafecopy && !src.needSafecopy { - return copy(dst.ToSlice(), src.ToSlice()), nil - } - - n := dst.length - if n > src.length { - n = src.length - } - if n == 0 { - return 0, nil - } - - switch { - case dst.needSafecopy && !src.needSafecopy: - return safecopy.CopyOut(dst.start, src.TakeFirst(n).ToSlice()) - case !dst.needSafecopy && src.needSafecopy: - return safecopy.CopyIn(dst.TakeFirst(n).ToSlice(), src.start) - case dst.needSafecopy && src.needSafecopy: - n64, err := safecopy.Copy(dst.start, src.start, uintptr(n)) - return int(n64), err - default: - panic("unreachable") - } -} - -// Zero sets all bytes in dst to 0 and returns the number of bytes zeroed. -func Zero(dst Block) (int, error) { - if !dst.needSafecopy { - bs := dst.ToSlice() - for i := range bs { - bs[i] = 0 - } - return len(bs), nil - } - - n64, err := safecopy.ZeroOut(dst.start, uintptr(dst.length)) - return int(n64), err -} - -// Safecopy atomics are no slower than non-safecopy atomics, so use the former -// even when !b.needSafecopy to get consistent alignment checking. - -// SwapUint32 invokes safecopy.SwapUint32 on the first 4 bytes of b. -// -// Preconditions: b.Len() >= 4. -func SwapUint32(b Block, new uint32) (uint32, error) { - if b.length < 4 { - panic(fmt.Sprintf("insufficient length: %d", b.length)) - } - return safecopy.SwapUint32(b.start, new) -} - -// SwapUint64 invokes safecopy.SwapUint64 on the first 8 bytes of b. -// -// Preconditions: b.Len() >= 8. -func SwapUint64(b Block, new uint64) (uint64, error) { - if b.length < 8 { - panic(fmt.Sprintf("insufficient length: %d", b.length)) - } - return safecopy.SwapUint64(b.start, new) -} - -// CompareAndSwapUint32 invokes safecopy.CompareAndSwapUint32 on the first 4 -// bytes of b. -// -// Preconditions: b.Len() >= 4. -func CompareAndSwapUint32(b Block, old, new uint32) (uint32, error) { - if b.length < 4 { - panic(fmt.Sprintf("insufficient length: %d", b.length)) - } - return safecopy.CompareAndSwapUint32(b.start, old, new) -} - -// LoadUint32 invokes safecopy.LoadUint32 on the first 4 bytes of b. -// -// Preconditions: b.Len() >= 4. -func LoadUint32(b Block) (uint32, error) { - if b.length < 4 { - panic(fmt.Sprintf("insufficient length: %d", b.length)) - } - return safecopy.LoadUint32(b.start) -} diff --git a/pkg/sentry/safemem/io.go b/pkg/sentry/safemem/io.go deleted file mode 100644 index f039a5c34..000000000 --- a/pkg/sentry/safemem/io.go +++ /dev/null @@ -1,392 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "errors" - "io" - "math" -) - -// ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write -// beyond the end of the BlockSeq. -var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq") - -// Reader represents a streaming byte source like io.Reader. -type Reader interface { - // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the - // number of bytes read. It may return a partial read without an error - // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a - // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil); - // note that this differs from io.Reader.Read (in particular, io.EOF should - // not be returned if ReadToBlocks successfully reads dsts.NumBytes() - // bytes.) - ReadToBlocks(dsts BlockSeq) (uint64, error) -} - -// Writer represents a streaming byte sink like io.Writer. -type Writer interface { - // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns - // the number of bytes written. It may return a partial write without an - // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not - // return a full write with an error (i.e. srcs.NumBytes(), err) where err - // != nil). - WriteFromBlocks(srcs BlockSeq) (uint64, error) -} - -// ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes() -// bytes have been read or ReadToBlocks returns an error. -func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) { - var done uint64 - for !dsts.IsEmpty() { - n, err := r.ReadToBlocks(dsts) - done += n - if err != nil { - return done, err - } - dsts = dsts.DropFirst64(n) - } - return done, nil -} - -// WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until -// srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error. -func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) { - var done uint64 - for !srcs.IsEmpty() { - n, err := w.WriteFromBlocks(srcs) - done += n - if err != nil { - return done, err - } - srcs = srcs.DropFirst64(n) - } - return done, nil -} - -// BlockSeqReader implements Reader by reading from a BlockSeq. -type BlockSeqReader struct { - Blocks BlockSeq -} - -// ReadToBlocks implements Reader.ReadToBlocks. -func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { - n, err := CopySeq(dsts, r.Blocks) - r.Blocks = r.Blocks.DropFirst64(n) - if err != nil { - return n, err - } - if n < dsts.NumBytes() { - return n, io.EOF - } - return n, nil -} - -// BlockSeqWriter implements Writer by writing to a BlockSeq. -type BlockSeqWriter struct { - Blocks BlockSeq -} - -// WriteFromBlocks implements Writer.WriteFromBlocks. -func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { - n, err := CopySeq(w.Blocks, srcs) - w.Blocks = w.Blocks.DropFirst64(n) - if err != nil { - return n, err - } - if n < srcs.NumBytes() { - return n, ErrEndOfBlockSeq - } - return n, nil -} - -// ReaderFunc implements Reader for a function with the semantics of -// Reader.ReadToBlocks. -type ReaderFunc func(dsts BlockSeq) (uint64, error) - -// ReadToBlocks implements Reader.ReadToBlocks. -func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { - return f(dsts) -} - -// WriterFunc implements Writer for a function with the semantics of -// Writer.WriteFromBlocks. -type WriterFunc func(srcs BlockSeq) (uint64, error) - -// WriteFromBlocks implements Writer.WriteFromBlocks. -func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { - return f(srcs) -} - -// ToIOReader implements io.Reader for a (safemem.)Reader. -// -// ToIOReader will return a successful partial read iff Reader.ReadToBlocks does -// so. -type ToIOReader struct { - Reader Reader -} - -// Read implements io.Reader.Read. -func (r ToIOReader) Read(dst []byte) (int, error) { - n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst))) - return int(n), err -} - -// ToIOWriter implements io.Writer for a (safemem.)Writer. -type ToIOWriter struct { - Writer Writer -} - -// Write implements io.Writer.Write. -func (w ToIOWriter) Write(src []byte) (int, error) { - // io.Writer does not permit partial writes. - n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src))) - return int(n), err -} - -// FromIOReader implements Reader for an io.Reader by repeatedly invoking -// io.Reader.Read until it returns an error or partial read. This is not -// thread-safe. -// -// FromIOReader will return a successful partial read iff Reader.Read does so. -type FromIOReader struct { - Reader io.Reader -} - -// ReadToBlocks implements Reader.ReadToBlocks. -func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { - var buf []byte - var done uint64 - for !dsts.IsEmpty() { - dst := dsts.Head() - var n int - var err error - n, buf, err = r.readToBlock(dst, buf) - done += uint64(n) - if n != dst.Len() { - return done, err - } - dsts = dsts.Tail() - if err != nil { - if dsts.IsEmpty() && err == io.EOF { - return done, nil - } - return done, err - } - } - return done, nil -} - -func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { - // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require - // safecopy. - if !dst.NeedSafecopy() { - n, err := r.Reader.Read(dst.ToSlice()) - return n, buf, err - } - if len(buf) < dst.Len() { - buf = make([]byte, dst.Len()) - } - rn, rerr := r.Reader.Read(buf[:dst.Len()]) - wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) - if wberr != nil { - return wbn, buf, wberr - } - return wbn, buf, rerr -} - -// FromIOReaderAt implements Reader for an io.ReaderAt. Does not repeatedly -// invoke io.ReaderAt.ReadAt because ReadAt is more strict than Read. A partial -// read indicates an error. This is not thread-safe. -type FromIOReaderAt struct { - ReaderAt io.ReaderAt - Offset int64 -} - -// ReadToBlocks implements Reader.ReadToBlocks. -func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) { - var buf []byte - var done uint64 - for !dsts.IsEmpty() { - dst := dsts.Head() - var n int - var err error - n, buf, err = r.readToBlock(dst, buf) - done += uint64(n) - if n != dst.Len() { - return done, err - } - dsts = dsts.Tail() - if err != nil { - if dsts.IsEmpty() && err == io.EOF { - return done, nil - } - return done, err - } - } - return done, nil -} - -func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) { - // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require - // safecopy. - if !dst.NeedSafecopy() { - n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset) - r.Offset += int64(n) - return n, buf, err - } - if len(buf) < dst.Len() { - buf = make([]byte, dst.Len()) - } - rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset) - r.Offset += int64(rn) - wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) - if wberr != nil { - return wbn, buf, wberr - } - return wbn, buf, rerr -} - -// FromIOWriter implements Writer for an io.Writer by repeatedly invoking -// io.Writer.Write until it returns an error or partial write. -// -// FromIOWriter will tolerate implementations of io.Writer.Write that return -// partial writes with a nil error in contravention of io.Writer's -// requirements, since Writer is permitted to do so. FromIOWriter will return a -// successful partial write iff Writer.Write does so. -type FromIOWriter struct { - Writer io.Writer -} - -// WriteFromBlocks implements Writer.WriteFromBlocks. -func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { - var buf []byte - var done uint64 - for !srcs.IsEmpty() { - src := srcs.Head() - var n int - var err error - n, buf, err = w.writeFromBlock(src, buf) - done += uint64(n) - if n != src.Len() || err != nil { - return done, err - } - srcs = srcs.Tail() - } - return done, nil -} - -func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) { - // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require - // safecopy. - if !src.NeedSafecopy() { - n, err := w.Writer.Write(src.ToSlice()) - return n, buf, err - } - if len(buf) < src.Len() { - buf = make([]byte, src.Len()) - } - bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src) - wn, werr := w.Writer.Write(buf[:bufn]) - if werr != nil { - return wn, buf, werr - } - return wn, buf, buferr -} - -// FromVecReaderFunc implements Reader for a function that reads data into a -// [][]byte and returns the number of bytes read as an int64. -type FromVecReaderFunc struct { - ReadVec func(dsts [][]byte) (int64, error) -} - -// ReadToBlocks implements Reader.ReadToBlocks. -// -// ReadToBlocks calls r.ReadVec at most once. -func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { - if dsts.IsEmpty() { - return 0, nil - } - // Ensure that we don't pass a [][]byte with a total length > MaxInt64. - dsts = dsts.TakeFirst64(uint64(math.MaxInt64)) - dstSlices := make([][]byte, 0, dsts.NumBlocks()) - // Buffer Blocks that require safecopy. - for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() { - dst := tmp.Head() - if dst.NeedSafecopy() { - dstSlices = append(dstSlices, make([]byte, dst.Len())) - } else { - dstSlices = append(dstSlices, dst.ToSlice()) - } - } - rn, rerr := r.ReadVec(dstSlices) - dsts = dsts.TakeFirst64(uint64(rn)) - var done uint64 - var i int - for !dsts.IsEmpty() { - dst := dsts.Head() - if dst.NeedSafecopy() { - n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i])) - done += uint64(n) - if err != nil { - return done, err - } - } else { - done += uint64(dst.Len()) - } - dsts = dsts.Tail() - i++ - } - return done, rerr -} - -// FromVecWriterFunc implements Writer for a function that writes data from a -// [][]byte and returns the number of bytes written. -type FromVecWriterFunc struct { - WriteVec func(srcs [][]byte) (int64, error) -} - -// WriteFromBlocks implements Writer.WriteFromBlocks. -// -// WriteFromBlocks calls w.WriteVec at most once. -func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { - if srcs.IsEmpty() { - return 0, nil - } - // Ensure that we don't pass a [][]byte with a total length > MaxInt64. - srcs = srcs.TakeFirst64(uint64(math.MaxInt64)) - srcSlices := make([][]byte, 0, srcs.NumBlocks()) - // Buffer Blocks that require safecopy. - var buferr error - for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() { - src := tmp.Head() - if src.NeedSafecopy() { - slice := make([]byte, src.Len()) - n, err := Copy(BlockFromSafeSlice(slice), src) - srcSlices = append(srcSlices, slice[:n]) - if err != nil { - buferr = err - break - } - } else { - srcSlices = append(srcSlices, src.ToSlice()) - } - } - n, err := w.WriteVec(srcSlices) - if err != nil { - return uint64(n), err - } - return uint64(n), buferr -} diff --git a/pkg/sentry/safemem/io_test.go b/pkg/sentry/safemem/io_test.go deleted file mode 100644 index 629741bee..000000000 --- a/pkg/sentry/safemem/io_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "io" - "testing" -) - -func makeBlocks(slices ...[]byte) []Block { - blocks := make([]Block, 0, len(slices)) - for _, s := range slices { - blocks = append(blocks, BlockFromSafeSlice(s)) - } - return blocks -} - -func TestFromIOReaderFullRead(t *testing.T) { - r := FromIOReader{bytes.NewBufferString("foobar")} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type eofHidingReader struct { - Reader io.Reader -} - -func (r eofHidingReader) Read(dst []byte) (int, error) { - n, err := r.Reader.Read(dst) - if err == io.EOF { - return n, nil - } - return n, err -} - -func TestFromIOReaderPartialRead(t *testing.T) { - r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the eofHidingReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type singleByteReader struct { - Reader io.Reader -} - -func (r singleByteReader) Read(dst []byte) (int, error) { - if len(dst) == 0 { - return r.Reader.Read(dst) - } - return r.Reader.Read(dst[:1]) -} - -func TestSingleByteReader(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the singleByteReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestReadFullToBlocks(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) - // ReadFullToBlocks should call into FromIOReader => singleByteReader - // repeatedly until dsts is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestFromIOWriterFullWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&dst} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type limitedWriter struct { - Writer io.Writer - Done int - Limit int -} - -func (w *limitedWriter) Write(src []byte) (int, error) { - count := len(src) - if count > (w.Limit - w.Done) { - count = w.Limit - w.Done - } - n, err := w.Writer.Write(src[:count]) - w.Done += n - return n, err -} - -func TestFromIOWriterPartialWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&limitedWriter{&dst, 0, 4}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the limitedWriter returns (1, nil) for a - // 3-byte write. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type singleByteWriter struct { - Writer io.Writer -} - -func (w singleByteWriter) Write(src []byte) (int, error) { - if len(src) == 0 { - return w.Writer.Write(src) - } - return w.Writer.Write(src[:1]) -} - -func TestSingleByteWriter(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the singleByteWriter returns (1, nil) - // for a 3-byte write. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestWriteFullToBlocks(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) - // WriteFullToBlocks should call into FromIOWriter => singleByteWriter - // repeatedly until srcs is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} diff --git a/pkg/sentry/safemem/safemem.go b/pkg/sentry/safemem/safemem.go deleted file mode 100644 index 3e70d33a2..000000000 --- a/pkg/sentry/safemem/safemem.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package safemem provides the Block and BlockSeq types. -package safemem diff --git a/pkg/sentry/safemem/seq_test.go b/pkg/sentry/safemem/seq_test.go deleted file mode 100644 index eba4bb535..000000000 --- a/pkg/sentry/safemem/seq_test.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "reflect" - "testing" -) - -type blockSeqTest struct { - desc string - - pieces []string - haveOffset bool - offset uint64 - haveLimit bool - limit uint64 - - want string -} - -func (t blockSeqTest) NonEmptyByteSlices() [][]byte { - // t is a value, so we can mutate it freely. - slices := make([][]byte, 0, len(t.pieces)) - for _, str := range t.pieces { - if t.haveOffset { - strOff := t.offset - if strOff > uint64(len(str)) { - strOff = uint64(len(str)) - } - str = str[strOff:] - t.offset -= strOff - } - if t.haveLimit { - strLim := t.limit - if strLim > uint64(len(str)) { - strLim = uint64(len(str)) - } - str = str[:strLim] - t.limit -= strLim - } - if len(str) != 0 { - slices = append(slices, []byte(str)) - } - } - return slices -} - -func (t blockSeqTest) BlockSeq() BlockSeq { - blocks := make([]Block, 0, len(t.pieces)) - for _, str := range t.pieces { - blocks = append(blocks, BlockFromSafeSlice([]byte(str))) - } - bs := BlockSeqFromSlice(blocks) - if t.haveOffset { - bs = bs.DropFirst64(t.offset) - } - if t.haveLimit { - bs = bs.TakeFirst64(t.limit) - } - return bs -} - -var blockSeqTests = []blockSeqTest{ - { - desc: "Empty sequence", - }, - { - desc: "Sequence of length 1", - pieces: []string{"foobar"}, - want: "foobar", - }, - { - desc: "Sequence of length 2", - pieces: []string{"foo", "bar"}, - want: "foobar", - }, - { - desc: "Empty Blocks", - pieces: []string{"", "foo", "", "", "bar", ""}, - want: "foobar", - }, - { - desc: "Sequence with non-zero offset", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - want: "obar", - }, - { - desc: "Sequence with non-maximal limit", - pieces: []string{"foo", "bar"}, - haveLimit: true, - limit: 5, - want: "fooba", - }, - { - desc: "Sequence with offset and limit", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - haveLimit: true, - limit: 3, - want: "oba", - }, -} - -func TestBlockSeqNumBytes(t *testing.T) { - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { - t.Errorf("NumBytes: got %d, wanted %d", got, want) - } - }) - } -} - -func TestBlockSeqIterBlocks(t *testing.T) { - // Tests BlockSeq iteration using Head/Tail. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - // "Note that a non-nil empty slice and a nil slice ... are not - // deeply equal." - reflect - slices := make([][]byte, 0, 0) - for !srcs.IsEmpty() { - src := srcs.Head() - slices = append(slices, src.ToSlice()) - nextSrcs := srcs.Tail() - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { - t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { - t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) - } - }) - } -} - -func TestBlockSeqIterBytes(t *testing.T) { - // Tests BlockSeq iteration using Head/DropFirst. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - var dst bytes.Buffer - for !srcs.IsEmpty() { - src := srcs.Head() - var b [1]byte - n, err := Copy(BlockFromSafeSlice(b[:]), src) - if n != 1 || err != nil { - t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) - } - dst.WriteByte(b[0]) - nextSrcs := srcs.DropFirst(1) - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { - t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if got := string(dst.Bytes()); got != test.want { - t.Errorf("Copied string: got %q, wanted %q", got, test.want) - } - }) - } -} - -func TestBlockSeqDropBeyondLimit(t *testing.T) { - blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} - bs := BlockSeqFromSlice(blocks) - if got, want := bs.NumBytes(), uint64(4); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.TakeFirst(1) - if got, want := bs.NumBytes(), uint64(1); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.DropFirst(2) - if got, want := bs.NumBytes(), uint64(0); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } -} diff --git a/pkg/sentry/safemem/seq_unsafe.go b/pkg/sentry/safemem/seq_unsafe.go deleted file mode 100644 index 354a95dde..000000000 --- a/pkg/sentry/safemem/seq_unsafe.go +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "fmt" - "reflect" - "unsafe" -) - -// A BlockSeq represents a sequence of Blocks, each of which has non-zero -// length. -// -// BlockSeqs are immutable and may be copied by value. The zero value of -// BlockSeq represents an empty sequence. -type BlockSeq struct { - // If length is 0, then the BlockSeq is empty. Invariants: data == 0; - // offset == 0; limit == 0. - // - // If length is -1, then the BlockSeq represents the single Block{data, - // limit, false}. Invariants: offset == 0; limit > 0; limit does not - // overflow the range of an int. - // - // If length is -2, then the BlockSeq represents the single Block{data, - // limit, true}. Invariants: offset == 0; limit > 0; limit does not - // overflow the range of an int. - // - // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks - // in the array of Blocks starting at address `data`, starting at `offset` - // bytes into the first Block and limited to the following `limit` bytes. - // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <= - // the combined length of all Blocks in the array; the first Block in the - // array has non-zero length. - // - // length is never 1; sequences consisting of a single Block are always - // stored inline (with length < 0). - data unsafe.Pointer - length int - offset int - limit uint64 -} - -// BlockSeqOf returns a BlockSeq representing the single Block b. -func BlockSeqOf(b Block) BlockSeq { - bs := BlockSeq{ - data: b.start, - length: -1, - limit: uint64(b.length), - } - if b.needSafecopy { - bs.length = -2 - } - return bs -} - -// BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice. -// If slice contains Blocks with zero length, BlockSeq will skip them during -// iteration. -// -// Whether the returned BlockSeq shares memory with slice is unspecified; -// clients should avoid mutating slices passed to BlockSeqFromSlice. -// -// Preconditions: The combined length of all Blocks in slice <= math.MaxUint64. -func BlockSeqFromSlice(slice []Block) BlockSeq { - slice = skipEmpty(slice) - var limit uint64 - for _, b := range slice { - sum := limit + uint64(b.Len()) - if sum < limit { - panic("BlockSeq length overflows uint64") - } - limit = sum - } - return blockSeqFromSliceLimited(slice, limit) -} - -// Preconditions: The combined length of all Blocks in slice <= limit. If -// len(slice) != 0, the first Block in slice has non-zero length, and limit > -// 0. -func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { - switch len(slice) { - case 0: - return BlockSeq{} - case 1: - return BlockSeqOf(slice[0].TakeFirst64(limit)) - default: - return BlockSeq{ - data: unsafe.Pointer(&slice[0]), - length: len(slice), - limit: limit, - } - } -} - -func skipEmpty(slice []Block) []Block { - for i, b := range slice { - if b.Len() != 0 { - return slice[i:] - } - } - return nil -} - -// IsEmpty returns true if bs contains no Blocks. -// -// Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0). -// (Of these, prefer to use bs.IsEmpty().) -func (bs BlockSeq) IsEmpty() bool { - return bs.length == 0 -} - -// NumBlocks returns the number of Blocks in bs. -func (bs BlockSeq) NumBlocks() int { - // In general, we have to count: if bs represents a windowed slice then the - // slice may contain Blocks with zero length, and bs.length may be larger - // than the actual number of Blocks due to bs.limit. - var n int - for !bs.IsEmpty() { - n++ - bs = bs.Tail() - } - return n -} - -// NumBytes returns the sum of Block.Len() for all Blocks in bs. -func (bs BlockSeq) NumBytes() uint64 { - return bs.limit -} - -// Head returns the first Block in bs. -// -// Preconditions: !bs.IsEmpty(). -func (bs BlockSeq) Head() Block { - if bs.length == 0 { - panic("empty BlockSeq") - } - if bs.length < 0 { - return bs.internalBlock() - } - return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit) -} - -// Preconditions: bs.length < 0. -func (bs BlockSeq) internalBlock() Block { - return Block{ - start: bs.data, - length: int(bs.limit), - needSafecopy: bs.length == -2, - } -} - -// Tail returns a BlockSeq consisting of all Blocks in bs after the first. -// -// Preconditions: !bs.IsEmpty(). -func (bs BlockSeq) Tail() BlockSeq { - if bs.length == 0 { - panic("empty BlockSeq") - } - if bs.length < 0 { - return BlockSeq{} - } - head := (*Block)(bs.data).DropFirst(bs.offset) - headLen := uint64(head.Len()) - if headLen >= bs.limit { - // The head Block exhausts the limit, so the tail is empty. - return BlockSeq{} - } - var extSlice []Block - extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice)) - extSliceHdr.Data = uintptr(bs.data) - extSliceHdr.Len = bs.length - extSliceHdr.Cap = bs.length - tailSlice := skipEmpty(extSlice[1:]) - tailLimit := bs.limit - headLen - return blockSeqFromSliceLimited(tailSlice, tailLimit) -} - -// DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes -// omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq. -// -// Preconditions: n >= 0. -func (bs BlockSeq) DropFirst(n int) BlockSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return bs.DropFirst64(uint64(n)) -} - -// DropFirst64 is equivalent to DropFirst but takes an uint64. -func (bs BlockSeq) DropFirst64(n uint64) BlockSeq { - if n >= bs.limit { - return BlockSeq{} - } - for { - // Calling bs.Head() here is surprisingly expensive, so inline getting - // the head's length. - var headLen uint64 - if bs.length < 0 { - headLen = bs.limit - } else { - headLen = uint64((*Block)(bs.data).Len() - bs.offset) - } - if n < headLen { - // Dropping ends partway through the head Block. - if bs.length < 0 { - return BlockSeqOf(bs.internalBlock().DropFirst64(n)) - } - bs.offset += int(n) - bs.limit -= n - return bs - } - n -= headLen - bs = bs.Tail() - } -} - -// TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n > -// bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs. -// -// Preconditions: n >= 0. -func (bs BlockSeq) TakeFirst(n int) BlockSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return bs.TakeFirst64(uint64(n)) -} - -// TakeFirst64 is equivalent to TakeFirst but takes a uint64. -func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq { - if n == 0 { - return BlockSeq{} - } - if bs.limit > n { - bs.limit = n - } - return bs -} - -// String implements fmt.Stringer.String. -func (bs BlockSeq) String() string { - var buf bytes.Buffer - buf.WriteByte('[') - var sep string - for !bs.IsEmpty() { - buf.WriteString(sep) - sep = " " - buf.WriteString(bs.Head().String()) - bs = bs.Tail() - } - buf.WriteByte(']') - return buf.String() -} - -// CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less, -// from srcs to dsts and returns the number of bytes copied. -// -// If srcs and dsts overlap, the data stored in dsts is unspecified. -func CopySeq(dsts, srcs BlockSeq) (uint64, error) { - var done uint64 - for !dsts.IsEmpty() && !srcs.IsEmpty() { - dst := dsts.Head() - src := srcs.Head() - n, err := Copy(dst, src) - done += uint64(n) - if err != nil { - return done, err - } - dsts = dsts.DropFirst(n) - srcs = srcs.DropFirst(n) - } - return done, nil -} - -// ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed. -func ZeroSeq(dsts BlockSeq) (uint64, error) { - var done uint64 - for !dsts.IsEmpty() { - n, err := Zero(dsts.Head()) - done += uint64(n) - if err != nil { - return done, err - } - dsts = dsts.DropFirst(n) - } - return done, nil -} diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 8e2b97afb..611fa22c3 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -9,15 +9,15 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 3850f6345..79e16d6e8 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -12,13 +12,13 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 1684dfc24..00265f15b 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -19,14 +19,14 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const maxInt = int(^uint(0) >> 1) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 42bf7be6a..5a07d5d0e 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -16,23 +16,23 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip/stack", + "//pkg/usermem", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index c957b0f1d..bde4c7a1e 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -21,19 +21,19 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index e69ec38c2..cd67234d2 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -19,14 +19,14 @@ import ( "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) func firstBytePtr(bs []byte) unsafe.Pointer { diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index e67b46c9e..034eca676 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -25,13 +25,13 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) var defaultRecvBufSize = inet.TCPBufferSize{ diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index ed34a8308..fa2a2cb66 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -15,10 +15,10 @@ go_library( "//pkg/binary", "//pkg/log", "//pkg/sentry/kernel", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index c65c36081..6ef740463 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -23,11 +23,11 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) // errorTargetName is used to mark targets as error targets. Error targets diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index baaac13c6..f8b8e467d 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -13,8 +13,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -25,11 +25,11 @@ go_library( "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ce0a1afd0..b21e0ca4b 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -20,7 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // alignUp rounds a length up to an alignment. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index be005df24..07f860a49 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -18,7 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 2137c7aeb..0234aadde 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -8,7 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 6b4a0ecf4..80a15d6cb 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -19,7 +19,7 @@ import ( "bytes" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index cea56f4ed..c4b95debb 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -32,11 +32,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD index 73fbdf1eb..b6434923c 100644 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -8,7 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/kernel", "//pkg/sentry/socket/netlink", "//pkg/syserr", diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go index b5d7808d7..1ee4296bc 100644 --- a/pkg/sentry/socket/netlink/uevent/protocol.go +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -20,7 +20,7 @@ package uevent import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/syserr" diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e3d1f90cb..ab01cb4fa 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -17,10 +17,11 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/log", "//pkg/metric", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -28,11 +29,9 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", @@ -45,6 +44,7 @@ go_library( "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 318acbeff..8619cc506 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -34,20 +34,19 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -57,6 +56,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 2d2c1ba2a..5afff2564 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -18,7 +18,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 2389a9cdb..50d9744e6 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,16 +24,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" ) // ControlMessages represents the union of unix control messages and tcpip diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index bade18686..08743deba 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -12,23 +12,23 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 2447f24ef..129949990 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -15,8 +15,8 @@ package unix import ( - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 4bdfc9208..74bcd6300 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -28,9 +28,9 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/ilist", "//pkg/refs", - "//pkg/sentry/context", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9e6fbc111..ce5b94ee7 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -16,7 +16,7 @@ package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0322dec0b..4b06d63ac 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -16,7 +16,7 @@ package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index fcc0da332..dcbafe0e5 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -19,7 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 7f49ba864..4d30aa714 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -22,9 +22,9 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -33,10 +33,10 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index ff6fafa63..762a946fe 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -34,7 +34,7 @@ go_library( "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", - "//pkg/sentry/usermem", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go index 5187594a7..074e80f9b 100644 --- a/pkg/sentry/strace/poll.go +++ b/pkg/sentry/strace/poll.go @@ -22,7 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // PollEventSet is the set of poll(2) event flags. diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go index c77d418e6..3a4c32aa0 100644 --- a/pkg/sentry/strace/select.go +++ b/pkg/sentry/strace/select.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) func fdsFromSet(t *kernel.Task, set []byte) []int { diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go index 5656d53eb..c41f36e3f 100644 --- a/pkg/sentry/strace/signal.go +++ b/pkg/sentry/strace/signal.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // signalNames contains the names of all named signals. diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index b6d7177f4..d2079c85f 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // SocketFamily are the possible socket(2) families. diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 629c1f308..3fc4a47fc 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -33,7 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // DefaultLogMaximumSize is the default LogMaximumSize. diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 7d74e0f70..8d6c52850 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -63,11 +63,12 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/bpf", + "//pkg/context", "//pkg/log", "//pkg/metric", "//pkg/rand", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/lock", @@ -87,16 +88,15 @@ go_library( "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index c76771a54..7435b50bf 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // AMD64 is a table of Linux amd64 syscall API with the corresponding syscall diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index d3587fda6..03a39fe65 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // ARM64 is a table of Linux arm64 syscall API with the corresponding syscall diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go index 333013d8c..2ddb2b146 100644 --- a/pkg/sentry/syscalls/linux/sigset.go +++ b/pkg/sentry/syscalls/linux/sigset.go @@ -17,8 +17,8 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // copyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index f56411bfe..b401978db 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // I/O commands. diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 65b4a227b..5f11b496c 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 9bc2445a5..c54735148 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -18,8 +18,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" @@ -28,8 +28,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/fasync" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // fileOpAt performs an operation on the second last component in the path. diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index bde17a767..b68261f72 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // futexWaitRestartBlock encapsulates the state required to restart futex(2) diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index 912cbe4ff..f66f4ffde 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Getdents implements linux syscall getdents(2) for 64bit systems. diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go index f5a519d8a..ac934dc6f 100644 --- a/pkg/sentry/syscalls/linux/sys_mempolicy.go +++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // We unconditionally report a single NUMA node. This also means that our diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 58a05b5bb..9959f6e61 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Brk implements linux syscall brk(2). diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index 8c13e2d82..eb5ff48f5 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Mount implements Linux syscall mount(2). diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 418d7fa5f..798344042 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // pipe2 implements the actual system call with flags. diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 2b2df989a..4f8762d7d 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go index bc4c588bf..c0aa0fd60 100644 --- a/pkg/sentry/syscalls/linux/sys_random.go +++ b/pkg/sentry/syscalls/linux/sys_random.go @@ -19,11 +19,11 @@ import ( "math" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index cd31e0649..f9f594190 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index 51e3f836b..e08c333d6 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // rlimit describes an implementation of 'struct rlimit', which may vary from diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index 18510ead8..5b7a66f4d 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index cde3b54e7..5f54f2456 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const opsMax = 500 // SEMOPM diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index fb6efd5d8..209be2990 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // "For a process to have permission to send a signal it must diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index cda517a81..2919228d0 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -26,9 +26,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 69b17b799..c841abccb 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Stat implements linux syscall stat(2). diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go index 58afb4a9a..75a567bd4 100644 --- a/pkg/sentry/syscalls/linux/sys_stat_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // copyOutStat copies the attributes (sattr, uattr) to the struct stat at diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go index 3e1251e0b..80c98d05c 100644 --- a/pkg/sentry/syscalls/linux/sys_stat_arm64.go +++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) // copyOutStat copies the attributes (sattr, uattr) to the struct stat at diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index b47c3b5c4..0c9e2255d 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -24,8 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const ( diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index b887fa9d7..2d2aa0819 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // The most significant 29 bits hold either a pid or a file descriptor. diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go index d4134207b..432351917 100644 --- a/pkg/sentry/syscalls/linux/sys_timer.go +++ b/pkg/sentry/syscalls/linux/sys_timer.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) const nsecPerSec = int64(time.Second) diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index ad4b67806..aba892939 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 77deb8980..efb95555c 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // GetXattr implements linux syscall getxattr(2). diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go index 4ff8f9234..ddc3ee26e 100644 --- a/pkg/sentry/syscalls/linux/timespec.go +++ b/pkg/sentry/syscalls/linux/timespec.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // copyTimespecIn copies a Timespec from the untrusted app range to the kernel. diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index 370fa6ec5..5d4aa3a63 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -14,7 +14,7 @@ go_library( srcs = ["events.go"], visibility = ["//:sandbox"], deps = [ + "//pkg/context", "//pkg/log", - "//pkg/sentry/context", ], ) diff --git a/pkg/sentry/unimpl/events.go b/pkg/sentry/unimpl/events.go index 79b5de9e4..73ed9372f 100644 --- a/pkg/sentry/unimpl/events.go +++ b/pkg/sentry/unimpl/events.go @@ -17,8 +17,8 @@ package unimpl import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" ) // contextID is the events package's type for context.Context.Value keys. diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD index e9c18f170..7467e6398 100644 --- a/pkg/sentry/uniqueid/BUILD +++ b/pkg/sentry/uniqueid/BUILD @@ -7,7 +7,7 @@ go_library( srcs = ["context.go"], visibility = ["//pkg/sentry:internal"], deps = [ - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/socket/unix/transport", ], ) diff --git a/pkg/sentry/uniqueid/context.go b/pkg/sentry/uniqueid/context.go index 4e466d66d..1fb884a90 100644 --- a/pkg/sentry/uniqueid/context.go +++ b/pkg/sentry/uniqueid/context.go @@ -17,7 +17,7 @@ package uniqueid import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" ) diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD deleted file mode 100644 index c8322e29e..000000000 --- a/pkg/sentry/usermem/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "usermem", - prefix = "Addr", - template = "//pkg/segment:generic_range", - types = { - "T": "Addr", - }, -) - -go_library( - name = "usermem", - srcs = [ - "access_type.go", - "addr.go", - "addr_range.go", - "addr_range_seq_unsafe.go", - "bytes_io.go", - "bytes_io_unsafe.go", - "usermem.go", - "usermem_arm64.go", - "usermem_unsafe.go", - "usermem_x86.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/atomicbitops", - "//pkg/binary", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/safemem", - "//pkg/syserror", - ], -) - -go_test( - name = "usermem_test", - size = "small", - srcs = [ - "addr_range_seq_test.go", - "usermem_test.go", - ], - library = ":usermem", - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/safemem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/usermem/README.md b/pkg/sentry/usermem/README.md deleted file mode 100644 index f6d2137eb..000000000 --- a/pkg/sentry/usermem/README.md +++ /dev/null @@ -1,31 +0,0 @@ -This package defines primitives for sentry access to application memory. - -Major types: - -- The `IO` interface represents a virtual address space and provides I/O - methods on that address space. `IO` is the lowest-level primitive. The - primary implementation of the `IO` interface is `mm.MemoryManager`. - -- `IOSequence` represents a collection of individually-contiguous address - ranges in a `IO` that is operated on sequentially, analogous to Linux's - `struct iov_iter`. - -Major usage patterns: - -- Access to a task's virtual memory, subject to the application's memory - protections and while running on that task's goroutine, from a context that - is at or above the level of the `kernel` package (e.g. most syscall - implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers - defined in `kernel/task_usermem.go`. - -- Access to a task's virtual memory, from a context that is at or above the - level of the `kernel` package, but where any of the above constraints does - not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory - protections); obtain the task's `mm.MemoryManager` by calling - `kernel.Task.MemoryManager`, and call its `IO` methods directly. - -- Access to a task's virtual memory, from a context that is below the level of - the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments - from higher layers, usually in the form of an `IOSequence`. The - `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions - in `kernel/task_usermem.go` are convenience functions for doing so. diff --git a/pkg/sentry/usermem/access_type.go b/pkg/sentry/usermem/access_type.go deleted file mode 100644 index 9c1742a59..000000000 --- a/pkg/sentry/usermem/access_type.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "syscall" -) - -// AccessType specifies memory access types. This is used for -// setting mapping permissions, as well as communicating faults. -// -// +stateify savable -type AccessType struct { - // Read is read access. - Read bool - - // Write is write access. - Write bool - - // Execute is executable access. - Execute bool -} - -// String returns a pretty representation of access. This looks like the -// familiar r-x, rw-, etc. and can be relied on as such. -func (a AccessType) String() string { - bits := [3]byte{'-', '-', '-'} - if a.Read { - bits[0] = 'r' - } - if a.Write { - bits[1] = 'w' - } - if a.Execute { - bits[2] = 'x' - } - return string(bits[:]) -} - -// Any returns true iff at least one of Read, Write or Execute is true. -func (a AccessType) Any() bool { - return a.Read || a.Write || a.Execute -} - -// Prot returns the system prot (syscall.PROT_READ, etc.) for this access. -func (a AccessType) Prot() int { - var prot int - if a.Read { - prot |= syscall.PROT_READ - } - if a.Write { - prot |= syscall.PROT_WRITE - } - if a.Execute { - prot |= syscall.PROT_EXEC - } - return prot -} - -// SupersetOf returns true iff the access types in a are a superset of the -// access types in other. -func (a AccessType) SupersetOf(other AccessType) bool { - if !a.Read && other.Read { - return false - } - if !a.Write && other.Write { - return false - } - if !a.Execute && other.Execute { - return false - } - return true -} - -// Intersect returns the access types set in both a and other. -func (a AccessType) Intersect(other AccessType) AccessType { - return AccessType{ - Read: a.Read && other.Read, - Write: a.Write && other.Write, - Execute: a.Execute && other.Execute, - } -} - -// Union returns the access types set in either a or other. -func (a AccessType) Union(other AccessType) AccessType { - return AccessType{ - Read: a.Read || other.Read, - Write: a.Write || other.Write, - Execute: a.Execute || other.Execute, - } -} - -// Effective returns the set of effective access types allowed by a, even if -// some types are not explicitly allowed. -func (a AccessType) Effective() AccessType { - // In Linux, Write and Execute access generally imply Read access. See - // mm/mmap.c:protection_map. - // - // The notable exception is get_user_pages, which only checks against - // the original vma flags. That said, most user memory accesses do not - // use GUP. - if a.Write || a.Execute { - a.Read = true - } - return a -} - -// Convenient access types. -var ( - NoAccess = AccessType{} - Read = AccessType{Read: true} - Write = AccessType{Write: true} - Execute = AccessType{Execute: true} - ReadWrite = AccessType{Read: true, Write: true} - AnyAccess = AccessType{Read: true, Write: true, Execute: true} -) diff --git a/pkg/sentry/usermem/addr.go b/pkg/sentry/usermem/addr.go deleted file mode 100644 index e79210804..000000000 --- a/pkg/sentry/usermem/addr.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "fmt" -) - -// Addr represents a generic virtual address. -// -// +stateify savable -type Addr uintptr - -// AddLength adds the given length to start and returns the result. ok is true -// iff adding the length did not overflow the range of Addr. -// -// Note: This function is usually used to get the end of an address range -// defined by its start address and length. Since the resulting end is -// exclusive, end == 0 is technically valid, and corresponds to a range that -// extends to the end of the address space, but ok will be false. This isn't -// expected to ever come up in practice. -func (v Addr) AddLength(length uint64) (end Addr, ok bool) { - end = v + Addr(length) - // The second half of the following check is needed in case uintptr is - // smaller than 64 bits. - ok = end >= v && length <= uint64(^Addr(0)) - return -} - -// RoundDown returns the address rounded down to the nearest page boundary. -func (v Addr) RoundDown() Addr { - return v & ^Addr(PageSize-1) -} - -// RoundUp returns the address rounded up to the nearest page boundary. ok is -// true iff rounding up did not wrap around. -func (v Addr) RoundUp() (addr Addr, ok bool) { - addr = Addr(v + PageSize - 1).RoundDown() - ok = addr >= v - return -} - -// MustRoundUp is equivalent to RoundUp, but panics if rounding up wraps -// around. -func (v Addr) MustRoundUp() Addr { - addr, ok := v.RoundUp() - if !ok { - panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v)) - } - return addr -} - -// HugeRoundDown returns the address rounded down to the nearest huge page -// boundary. -func (v Addr) HugeRoundDown() Addr { - return v & ^Addr(HugePageSize-1) -} - -// HugeRoundUp returns the address rounded up to the nearest huge page boundary. -// ok is true iff rounding up did not wrap around. -func (v Addr) HugeRoundUp() (addr Addr, ok bool) { - addr = Addr(v + HugePageSize - 1).HugeRoundDown() - ok = addr >= v - return -} - -// PageOffset returns the offset of v into the current page. -func (v Addr) PageOffset() uint64 { - return uint64(v & Addr(PageSize-1)) -} - -// IsPageAligned returns true if v.PageOffset() == 0. -func (v Addr) IsPageAligned() bool { - return v.PageOffset() == 0 -} - -// AddrRange is a range of Addrs. -// -// type AddrRange - -// ToRange returns [v, v+length). -func (v Addr) ToRange(length uint64) (AddrRange, bool) { - end, ok := v.AddLength(length) - return AddrRange{v, end}, ok -} - -// IsPageAligned returns true if ar.Start.IsPageAligned() and -// ar.End.IsPageAligned(). -func (ar AddrRange) IsPageAligned() bool { - return ar.Start.IsPageAligned() && ar.End.IsPageAligned() -} - -// String implements fmt.Stringer.String. -func (ar AddrRange) String() string { - return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End) -} diff --git a/pkg/sentry/usermem/addr_range_seq_test.go b/pkg/sentry/usermem/addr_range_seq_test.go deleted file mode 100644 index 82f735026..000000000 --- a/pkg/sentry/usermem/addr_range_seq_test.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "testing" -) - -var addrRangeSeqTests = []struct { - desc string - ranges []AddrRange -}{ - { - desc: "Empty sequence", - }, - { - desc: "Single empty AddrRange", - ranges: []AddrRange{ - {0x10, 0x10}, - }, - }, - { - desc: "Single non-empty AddrRange of length 1", - ranges: []AddrRange{ - {0x10, 0x11}, - }, - }, - { - desc: "Single non-empty AddrRange of length 2", - ranges: []AddrRange{ - {0x10, 0x12}, - }, - }, - { - desc: "Multiple non-empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - }, - }, - { - desc: "Multiple AddrRanges including empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x10}, - {0x20, 0x20}, - {0x30, 0x33}, - {0x40, 0x44}, - {0x50, 0x50}, - {0x60, 0x60}, - {0x70, 0x77}, - {0x80, 0x88}, - {0x90, 0x90}, - {0xa0, 0xa0}, - }, - }, -} - -func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) { - var wantLen int64 - for _, ar := range wantRanges { - wantLen += int64(ar.Length()) - } - - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted ", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - ars = ars.Tail() - wantLen -= int64(got.Length()) - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - if gotN := ars.NumRanges(); gotN != 0 { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN) - } -} - -func TestAddrRangeSeqTailIteration(t *testing.T) { - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges) - }) - } -} - -func TestAddrRangeSeqDropFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.DropFirst(1), ars; got != want { - t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) { - // Tests AddrRangeSeq iteration using Head/DropFirst, simulating - // I/O-per-AddrRange. - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - // Figure out what AddrRanges we expect to see. - var wantLen int64 - var wantRanges []AddrRange - for _, ar := range test.ranges { - wantLen += int64(ar.Length()) - wantRanges = append(wantRanges, ar) - if ar.Length() == 0 { - // We "do" 0 bytes of I/O and then call DropFirst(0), - // advancing to the next AddrRange. - continue - } - // Otherwise we "do" 1 byte of I/O and then call DropFirst(1), - // advancing the AddrRange by 1 byte, or to the next AddrRange - // if this one is exhausted. - for ar.Start++; ar.Length() != 0; ar.Start++ { - wantRanges = append(wantRanges, ar) - } - } - t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen) - - ars := AddrRangeSeqFromSlice(test.ranges) - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted ", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - if got.Length() == 0 { - ars = ars.DropFirst(0) - } else { - ars = ars.DropFirst(1) - wantLen-- - } - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - }) - } -} - -func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.TakeFirst(1), ars; got != want { - t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqTakeFirst(t *testing.T) { - ranges := []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - {0x30, 0x30}, - {0x40, 0x44}, - {0x50, 0x55}, - {0x60, 0x60}, - {0x70, 0x77}, - } - ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5) - want := []AddrRange{ - {0x10, 0x11}, // +1 byte (total 1 byte), not truncated - {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated - {0x30, 0x30}, // +0 bytes (total 3 bytes), no change - {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated - {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated - {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change) - {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated - } - testAddrRangeSeqEqualityWithTailIteration(t, ars, want) -} diff --git a/pkg/sentry/usermem/addr_range_seq_unsafe.go b/pkg/sentry/usermem/addr_range_seq_unsafe.go deleted file mode 100644 index c09337c15..000000000 --- a/pkg/sentry/usermem/addr_range_seq_unsafe.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "bytes" - "fmt" - "reflect" - "unsafe" -) - -// An AddrRangeSeq represents a sequence of AddrRanges. -// -// AddrRangeSeqs are immutable and may be copied by value. The zero value of -// AddrRangeSeq represents an empty sequence. -// -// An AddrRangeSeq may contain AddrRanges with a length of 0. This is necessary -// since zero-length AddrRanges are significant to MM bounds checks. -type AddrRangeSeq struct { - // If length is 0, then the AddrRangeSeq represents no AddrRanges. - // Invariants: data == 0; offset == 0; limit == 0. - // - // If length is 1, then the AddrRangeSeq represents the single - // AddrRange{offset, offset+limit}. Invariants: data == 0. - // - // Otherwise, length >= 2, and the AddrRangeSeq represents the `length` - // AddrRanges in the array of AddrRanges starting at address `data`, - // starting at `offset` bytes into the first AddrRange and limited to the - // following `limit` bytes. (AddrRanges after `limit` are still iterated, - // but are truncated to a length of 0.) Invariants: data != 0; offset <= - // data[0].Length(); limit > 0; offset+limit <= the combined length of all - // AddrRanges in the array. - data unsafe.Pointer - length int - offset Addr - limit Addr -} - -// AddrRangeSeqOf returns an AddrRangeSeq representing the single AddrRange ar. -func AddrRangeSeqOf(ar AddrRange) AddrRangeSeq { - return AddrRangeSeq{ - length: 1, - offset: ar.Start, - limit: ar.Length(), - } -} - -// AddrRangeSeqFromSlice returns an AddrRangeSeq representing all AddrRanges in -// slice. -// -// Whether the returned AddrRangeSeq shares memory with slice is unspecified; -// clients should avoid mutating slices passed to AddrRangeSeqFromSlice. -// -// Preconditions: The combined length of all AddrRanges in slice <= -// math.MaxInt64. -func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq { - var limit int64 - for _, ar := range slice { - len64 := int64(ar.Length()) - if len64 < 0 { - panic(fmt.Sprintf("Length of AddrRange %v overflows int64", ar)) - } - sum := limit + len64 - if sum < limit { - panic(fmt.Sprintf("Total length of AddrRanges %v overflows int64", slice)) - } - limit = sum - } - return addrRangeSeqFromSliceLimited(slice, limit) -} - -// Preconditions: The combined length of all AddrRanges in slice <= limit. -// limit >= 0. If len(slice) != 0, then limit > 0. -func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq { - switch len(slice) { - case 0: - return AddrRangeSeq{} - case 1: - return AddrRangeSeq{ - length: 1, - offset: slice[0].Start, - limit: Addr(limit), - } - default: - return AddrRangeSeq{ - data: unsafe.Pointer(&slice[0]), - length: len(slice), - limit: Addr(limit), - } - } -} - -// IsEmpty returns true if ars.NumRanges() == 0. -// -// Note that since AddrRangeSeq may contain AddrRanges with a length of zero, -// an AddrRange representing 0 bytes (AddrRangeSeq.NumBytes() == 0) is not -// necessarily empty. -func (ars AddrRangeSeq) IsEmpty() bool { - return ars.length == 0 -} - -// NumRanges returns the number of AddrRanges in ars. -func (ars AddrRangeSeq) NumRanges() int { - return ars.length -} - -// NumBytes returns the number of bytes represented by ars. -func (ars AddrRangeSeq) NumBytes() int64 { - return int64(ars.limit) -} - -// Head returns the first AddrRange in ars. -// -// Preconditions: !ars.IsEmpty(). -func (ars AddrRangeSeq) Head() AddrRange { - if ars.length == 0 { - panic("empty AddrRangeSeq") - } - if ars.length == 1 { - return AddrRange{ars.offset, ars.offset + ars.limit} - } - ar := *(*AddrRange)(ars.data) - ar.Start += ars.offset - if ar.Length() > ars.limit { - ar.End = ar.Start + ars.limit - } - return ar -} - -// Tail returns an AddrRangeSeq consisting of all AddrRanges in ars after the -// first. -// -// Preconditions: !ars.IsEmpty(). -func (ars AddrRangeSeq) Tail() AddrRangeSeq { - if ars.length == 0 { - panic("empty AddrRangeSeq") - } - if ars.length == 1 { - return AddrRangeSeq{} - } - return ars.externalTail() -} - -// Preconditions: ars.length >= 2. -func (ars AddrRangeSeq) externalTail() AddrRangeSeq { - headLen := (*AddrRange)(ars.data).Length() - ars.offset - var tailLimit int64 - if ars.limit > headLen { - tailLimit = int64(ars.limit - headLen) - } - var extSlice []AddrRange - extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice)) - extSliceHdr.Data = uintptr(ars.data) - extSliceHdr.Len = ars.length - extSliceHdr.Cap = ars.length - return addrRangeSeqFromSliceLimited(extSlice[1:], tailLimit) -} - -// DropFirst returns an AddrRangeSeq equivalent to ars, but with the first n -// bytes omitted. If n > ars.NumBytes(), DropFirst returns an empty -// AddrRangeSeq. -// -// If !ars.IsEmpty() and ars.Head().Length() == 0, DropFirst will always omit -// at least ars.Head(), even if n == 0. This guarantees that the basic pattern -// of: -// -// for !ars.IsEmpty() { -// n, err = doIOWith(ars.Head()) -// if err != nil { -// return err -// } -// ars = ars.DropFirst(n) -// } -// -// works even in the presence of zero-length AddrRanges. -// -// Preconditions: n >= 0. -func (ars AddrRangeSeq) DropFirst(n int) AddrRangeSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return ars.DropFirst64(int64(n)) -} - -// DropFirst64 is equivalent to DropFirst but takes an int64. -func (ars AddrRangeSeq) DropFirst64(n int64) AddrRangeSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - if Addr(n) > ars.limit { - return AddrRangeSeq{} - } - // Handle initial empty AddrRange. - switch ars.length { - case 0: - return AddrRangeSeq{} - case 1: - if ars.limit == 0 { - return AddrRangeSeq{} - } - default: - if rawHeadLen := (*AddrRange)(ars.data).Length(); ars.offset == rawHeadLen { - ars = ars.externalTail() - } - } - for n != 0 { - // Calling ars.Head() here is surprisingly expensive, so inline getting - // the head's length. - var headLen Addr - if ars.length == 1 { - headLen = ars.limit - } else { - headLen = (*AddrRange)(ars.data).Length() - ars.offset - } - if Addr(n) < headLen { - // Dropping ends partway through the head AddrRange. - ars.offset += Addr(n) - ars.limit -= Addr(n) - return ars - } - n -= int64(headLen) - ars = ars.Tail() - } - return ars -} - -// TakeFirst returns an AddrRangeSeq equivalent to ars, but iterating at most n -// bytes. TakeFirst never removes AddrRanges from ars; AddrRanges beyond the -// first n bytes are reduced to a length of zero, but will still be iterated. -// -// Preconditions: n >= 0. -func (ars AddrRangeSeq) TakeFirst(n int) AddrRangeSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - return ars.TakeFirst64(int64(n)) -} - -// TakeFirst64 is equivalent to TakeFirst but takes an int64. -func (ars AddrRangeSeq) TakeFirst64(n int64) AddrRangeSeq { - if n < 0 { - panic(fmt.Sprintf("invalid n: %d", n)) - } - if ars.limit > Addr(n) { - ars.limit = Addr(n) - } - return ars -} - -// String implements fmt.Stringer.String. -func (ars AddrRangeSeq) String() string { - // This is deliberately chosen to be the same as fmt's automatic stringer - // for []AddrRange. - var buf bytes.Buffer - buf.WriteByte('[') - var sep string - for !ars.IsEmpty() { - buf.WriteString(sep) - sep = " " - buf.WriteString(ars.Head().String()) - ars = ars.Tail() - } - buf.WriteByte(']') - return buf.String() -} diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go deleted file mode 100644 index 7898851b3..000000000 --- a/pkg/sentry/usermem/bytes_io.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/syserror" -) - -const maxInt = int(^uint(0) >> 1) - -// BytesIO implements IO using a byte slice. Addresses are interpreted as -// offsets into the slice. Reads and writes beyond the end of the slice return -// EFAULT. -type BytesIO struct { - Bytes []byte -} - -// CopyOut implements IO.CopyOut. -func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) { - rngN, rngErr := b.rangeCheck(addr, len(src)) - if rngN == 0 { - return 0, rngErr - } - return copy(b.Bytes[int(addr):], src[:rngN]), rngErr -} - -// CopyIn implements IO.CopyIn. -func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) { - rngN, rngErr := b.rangeCheck(addr, len(dst)) - if rngN == 0 { - return 0, rngErr - } - return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr -} - -// ZeroOut implements IO.ZeroOut. -func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) { - if toZero > int64(maxInt) { - return 0, syserror.EINVAL - } - rngN, rngErr := b.rangeCheck(addr, int(toZero)) - if rngN == 0 { - return 0, rngErr - } - zeroSlice := b.Bytes[int(addr) : int(addr)+rngN] - for i := range zeroSlice { - zeroSlice[i] = 0 - } - return int64(rngN), rngErr -} - -// CopyOutFrom implements IO.CopyOutFrom. -func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) { - dsts, rngErr := b.blocksFromAddrRanges(ars) - n, err := src.ReadToBlocks(dsts) - if err != nil { - return int64(n), err - } - return int64(n), rngErr -} - -// CopyInTo implements IO.CopyInTo. -func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) { - srcs, rngErr := b.blocksFromAddrRanges(ars) - n, err := dst.WriteFromBlocks(srcs) - if err != nil { - return int64(n), err - } - return int64(n), rngErr -} - -func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { - if length == 0 { - return 0, nil - } - if length < 0 { - return 0, syserror.EINVAL - } - max := Addr(len(b.Bytes)) - if addr >= max { - return 0, syserror.EFAULT - } - end, ok := addr.AddLength(uint64(length)) - if !ok || end > max { - return int(max - addr), syserror.EFAULT - } - return length, nil -} - -func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { - switch ars.NumRanges() { - case 0: - return safemem.BlockSeq{}, nil - case 1: - block, err := b.blockFromAddrRange(ars.Head()) - return safemem.BlockSeqOf(block), err - default: - blocks := make([]safemem.Block, 0, ars.NumRanges()) - for !ars.IsEmpty() { - block, err := b.blockFromAddrRange(ars.Head()) - if block.Len() != 0 { - blocks = append(blocks, block) - } - if err != nil { - return safemem.BlockSeqFromSlice(blocks), err - } - ars = ars.Tail() - } - return safemem.BlockSeqFromSlice(blocks), nil - } -} - -func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { - n, err := b.rangeCheck(ar.Start, int(ar.Length())) - if n == 0 { - return safemem.Block{}, err - } - return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err -} - -// BytesIOSequence returns an IOSequence representing the given byte slice. -func BytesIOSequence(buf []byte) IOSequence { - return IOSequence{ - IO: &BytesIO{buf}, - Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}), - } -} diff --git a/pkg/sentry/usermem/bytes_io_unsafe.go b/pkg/sentry/usermem/bytes_io_unsafe.go deleted file mode 100644 index fca5952f4..000000000 --- a/pkg/sentry/usermem/bytes_io_unsafe.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "sync/atomic" - "unsafe" - - "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/sentry/context" -) - -// SwapUint32 implements IO.SwapUint32. -func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) { - if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { - return 0, rngErr - } - return atomic.SwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), new), nil -} - -// CompareAndSwapUint32 implements IO.CompareAndSwapUint32. -func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) { - if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { - return 0, rngErr - } - return atomicbitops.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), old, new), nil -} - -// LoadUint32 implements IO.LoadUint32. -func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) { - if _, err := b.rangeCheck(addr, 4); err != nil { - return 0, err - } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)]))), nil -} diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go deleted file mode 100644 index 7b1f312b1..000000000 --- a/pkg/sentry/usermem/usermem.go +++ /dev/null @@ -1,597 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package usermem governs access to user memory. -package usermem - -import ( - "bytes" - "errors" - "io" - "strconv" - - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// IO provides access to the contents of a virtual memory space. -// -// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any -// meaningful data. -type IO interface { - // CopyOut copies len(src) bytes from src to the memory mapped at addr. It - // returns the number of bytes copied. If the number of bytes copied is < - // len(src), it returns a non-nil error explaining why. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. - // - // Postconditions: CopyOut does not retain src. - CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) - - // CopyIn copies len(dst) bytes from the memory mapped at addr to dst. - // It returns the number of bytes copied. If the number of bytes copied is - // < len(dst), it returns a non-nil error explaining why. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. - // - // Postconditions: CopyIn does not retain dst. - CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) - - // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number - // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a - // non-nil error explaining why. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. toZero >= 0. - ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) - - // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at - // ars. It returns the number of bytes copied, which may be less than the - // number of bytes read from src if copying fails. CopyOutFrom may return a - // partial copy without an error iff src.ReadToBlocks returns a partial - // read without an error. - // - // CopyOutFrom calls src.ReadToBlocks at most once. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. src.ReadToBlocks must not block - // on mm.MemoryManager.activeMu or any preceding locks in the lock order. - CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) - - // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to - // dst. It returns the number of bytes copied. CopyInTo may return a - // partial copy without an error iff dst.WriteFromBlocks returns a partial - // write without an error. - // - // CopyInTo calls dst.WriteFromBlocks at most once. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. dst.WriteFromBlocks must not - // block on mm.MemoryManager.activeMu or any preceding locks in the lock - // order. - CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) - - // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst - // at most once, which is unnecessary in most cases, forces implementations - // to gather safemem.Blocks into a single slice to pass to src/dst. Add - // CopyOutFromIter/CopyInToIter, which relaxes this restriction, to avoid - // this allocation. - - // SwapUint32 atomically sets the uint32 value at addr to new and - // returns the previous value. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. - SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) - - // CompareAndSwapUint32 atomically compares the uint32 value at addr to - // old; if they are equal, the value in memory is replaced by new. In - // either case, the previous value stored in memory is returned. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. - CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) - - // LoadUint32 atomically loads the uint32 value at addr and returns it. - // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. - LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) -} - -// IOOpts contains options applicable to all IO methods. -type IOOpts struct { - // If IgnorePermissions is true, application-defined memory protections set - // by mmap(2) or mprotect(2) will be ignored. (Memory protections required - // by the target of the mapping are never ignored.) - IgnorePermissions bool - - // If AddressSpaceActive is true, the IO implementation may assume that it - // has an active AddressSpace and can therefore use AddressSpace copying - // without performing activation. See mm/io.go for details. - AddressSpaceActive bool -} - -// IOReadWriter is an io.ReadWriter that reads from / writes to addresses -// starting at addr in IO. The preconditions that apply to IO.CopyIn and -// IO.CopyOut also apply to IOReadWriter.Read and IOReadWriter.Write -// respectively. -type IOReadWriter struct { - Ctx context.Context - IO IO - Addr Addr - Opts IOOpts -} - -// Read implements io.Reader.Read. -// -// Note that an address space does not have an "end of file", so Read can only -// return io.EOF if IO.CopyIn returns io.EOF. Attempts to read unmapped or -// unreadable memory, or beyond the end of the address space, should return -// EFAULT. -func (rw *IOReadWriter) Read(dst []byte) (int, error) { - n, err := rw.IO.CopyIn(rw.Ctx, rw.Addr, dst, rw.Opts) - end, ok := rw.Addr.AddLength(uint64(n)) - if ok { - rw.Addr = end - } else { - // Disallow wraparound. - rw.Addr = ^Addr(0) - if err != nil { - err = syserror.EFAULT - } - } - return n, err -} - -// Writer implements io.Writer.Write. -func (rw *IOReadWriter) Write(src []byte) (int, error) { - n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts) - end, ok := rw.Addr.AddLength(uint64(n)) - if ok { - rw.Addr = end - } else { - // Disallow wraparound. - rw.Addr = ^Addr(0) - if err != nil { - err = syserror.EFAULT - } - } - return n, err -} - -// CopyObjectOut copies a fixed-size value or slice of fixed-size values from -// src to the memory mapped at addr in uio. It returns the number of bytes -// copied. -// -// CopyObjectOut must use reflection to encode src; performance-sensitive -// clients should do encoding manually and use uio.CopyOut directly. -// -// Preconditions: As for IO.CopyOut. -func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { - w := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - // Allocate a byte slice the size of the object being marshaled. This - // adds an extra reflection call, but avoids needing to grow the slice - // during encoding, which can result in many heap-allocated slices. - b := make([]byte, 0, binary.Size(src)) - return w.Write(binary.Marshal(b, ByteOrder, src)) -} - -// CopyObjectIn copies a fixed-size value or slice of fixed-size values from -// the memory mapped at addr in uio to dst. It returns the number of bytes -// copied. -// -// CopyObjectIn must use reflection to decode dst; performance-sensitive -// clients should use uio.CopyIn directly and do decoding manually. -// -// Preconditions: As for IO.CopyIn. -func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { - r := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - buf := make([]byte, binary.Size(dst)) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - binary.Unmarshal(buf, ByteOrder, dst) - return int(r.Addr - addr), nil -} - -// CopyStringIn tuning parameters, defined outside that function for tests. -const ( - copyStringIncrement = 64 - copyStringMaxInitBufLen = 256 -) - -// CopyStringIn copies a NUL-terminated string of unknown length from the -// memory mapped at addr in uio and returns it as a string (not including the -// trailing NUL). If the length of the string, including the terminating NUL, -// would exceed maxlen, CopyStringIn returns the string truncated to maxlen and -// ENAMETOOLONG. -// -// Preconditions: As for IO.CopyFromUser. maxlen >= 0. -func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { - initLen := maxlen - if initLen > copyStringMaxInitBufLen { - initLen = copyStringMaxInitBufLen - } - buf := make([]byte, initLen) - var done int - for done < maxlen { - // Read up to copyStringIncrement bytes at a time. - readlen := copyStringIncrement - if readlen > maxlen-done { - readlen = maxlen - done - } - end, ok := addr.AddLength(uint64(readlen)) - if !ok { - return stringFromImmutableBytes(buf[:done]), syserror.EFAULT - } - // Shorten the read to avoid crossing page boundaries, since faulting - // in a page unnecessarily is expensive. This also ensures that partial - // copies up to the end of application-mappable memory succeed. - if addr.RoundDown() != end.RoundDown() { - end = end.RoundDown() - readlen = int(end - addr) - } - // Ensure that our buffer is large enough to accommodate the read. - if done+readlen > len(buf) { - newBufLen := len(buf) * 2 - if newBufLen > maxlen { - newBufLen = maxlen - } - buf = append(buf, make([]byte, newBufLen-len(buf))...) - } - n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) - // Look for the terminating zero byte, which may have occurred before - // hitting err. - if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { - return stringFromImmutableBytes(buf[:done+i]), nil - } - - done += n - if err != nil { - return stringFromImmutableBytes(buf[:done]), err - } - addr = end - } - return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG -} - -// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The -// maximum number of bytes copied is ars.NumBytes() or len(src), whichever is -// less. CopyOutVec returns the number of bytes copied; if this is less than -// the maximum, it returns a non-nil error explaining why. -// -// Preconditions: As for IO.CopyOut. -func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) { - var done int - for !ars.IsEmpty() && done < len(src) { - ar := ars.Head() - cplen := len(src) - done - if Addr(cplen) >= ar.Length() { - cplen = int(ar.Length()) - } - n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts) - done += n - if err != nil { - return done, err - } - ars = ars.DropFirst(n) - } - return done, nil -} - -// CopyInVec copies bytes from the memory mapped at ars in uio to dst. The -// maximum number of bytes copied is ars.NumBytes() or len(dst), whichever is -// less. CopyInVec returns the number of bytes copied; if this is less than the -// maximum, it returns a non-nil error explaining why. -// -// Preconditions: As for IO.CopyIn. -func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { - var done int - for !ars.IsEmpty() && done < len(dst) { - ar := ars.Head() - cplen := len(dst) - done - if Addr(cplen) >= ar.Length() { - cplen = int(ar.Length()) - } - n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts) - done += n - if err != nil { - return done, err - } - ars = ars.DropFirst(n) - } - return done, nil -} - -// ZeroOutVec writes zeroes to the memory mapped at ars in uio. The maximum -// number of bytes written is ars.NumBytes() or toZero, whichever is less. -// ZeroOutVec returns the number of bytes written; if this is less than the -// maximum, it returns a non-nil error explaining why. -// -// Preconditions: As for IO.ZeroOut. -func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { - var done int64 - for !ars.IsEmpty() && done < toZero { - ar := ars.Head() - cplen := toZero - done - if Addr(cplen) >= ar.Length() { - cplen = int64(ar.Length()) - } - n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts) - done += n - if err != nil { - return done, err - } - ars = ars.DropFirst64(n) - } - return done, nil -} - -func isASCIIWhitespace(b byte) bool { - // Compare Linux include/linux/ctype.h, lib/ctype.c. - // 9 => horizontal tab '\t' - // 10 => line feed '\n' - // 11 => vertical tab '\v' - // 12 => form feed '\c' - // 13 => carriage return '\r' - return b == ' ' || (b >= 9 && b <= 13) -} - -// CopyInt32StringsInVec copies up to len(dsts) whitespace-separated decimal -// strings from the memory mapped at ars in uio and converts them to int32 -// values in dsts. It returns the number of bytes read. -// -// CopyInt32StringsInVec shares the following properties with Linux's -// kernel/sysctl.c:proc_dointvec(write=1): -// -// - If any read value overflows the range of int32, or any invalid characters -// are encountered during the read, CopyInt32StringsInVec returns EINVAL. -// -// - If, upon reaching the end of ars, fewer than len(dsts) values have been -// read, CopyInt32StringsInVec returns no error if at least 1 value was read -// and EINVAL otherwise. -// -// - Trailing whitespace after the last successfully read value is counted in -// the number of bytes read. -// -// Unlike proc_dointvec(): -// -// - CopyInt32StringsInVec does not implicitly limit ars.NumBytes() to -// PageSize-1; callers that require this must do so explicitly. -// -// - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0. -// -// Preconditions: As for CopyInVec. -func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { - if len(dsts) == 0 { - return 0, nil - } - - buf := make([]byte, ars.NumBytes()) - n, cperr := CopyInVec(ctx, uio, ars, buf, opts) - buf = buf[:n] - - var i, j int - for ; j < len(dsts); j++ { - // Skip leading whitespace. - for i < len(buf) && isASCIIWhitespace(buf[i]) { - i++ - } - if i == len(buf) { - break - } - - // Find the end of the value to be parsed (next whitespace or end of string). - nextI := i + 1 - for nextI < len(buf) && !isASCIIWhitespace(buf[nextI]) { - nextI++ - } - - // Parse a single value. - val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32) - if err != nil { - return int64(i), syserror.EINVAL - } - dsts[j] = int32(val) - - i = nextI - } - - // Skip trailing whitespace. - for i < len(buf) && isASCIIWhitespace(buf[i]) { - i++ - } - - if cperr != nil { - return int64(i), cperr - } - if j == 0 { - return int64(i), syserror.EINVAL - } - return int64(i), nil -} - -// CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at -// most one int32. -func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) { - dsts := [1]int32{*dst} - n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts) - *dst = dsts[0] - return n, err -} - -// IOSequence holds arguments to IO methods. -type IOSequence struct { - IO IO - Addrs AddrRangeSeq - Opts IOOpts -} - -// NumBytes returns s.Addrs.NumBytes(). -// -// Note that NumBytes() may return 0 even if !s.Addrs.IsEmpty(), since -// s.Addrs may contain a non-zero number of zero-length AddrRanges. -// Many clients of -// IOSequence currently do something like: -// -// if ioseq.NumBytes() == 0 { -// return 0, nil -// } -// if f.availableBytes == 0 { -// return 0, syserror.ErrWouldBlock -// } -// return ioseq.CopyOutFrom(..., reader) -// -// In such cases, using s.Addrs.IsEmpty() will cause them to have the wrong -// behavior for zero-length I/O. However, using s.NumBytes() == 0 instead means -// that we will return success for zero-length I/O in cases where Linux would -// return EFAULT due to a failed access_ok() check, so in the long term we -// should move checks for ErrWouldBlock etc. into the body of -// reader.ReadToBlocks and use s.Addrs.IsEmpty() instead. -func (s IOSequence) NumBytes() int64 { - return s.Addrs.NumBytes() -} - -// DropFirst returns a copy of s with s.Addrs.DropFirst(n). -// -// Preconditions: As for AddrRangeSeq.DropFirst. -func (s IOSequence) DropFirst(n int) IOSequence { - return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts} -} - -// DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n). -// -// Preconditions: As for AddrRangeSeq.DropFirst64. -func (s IOSequence) DropFirst64(n int64) IOSequence { - return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts} -} - -// TakeFirst returns a copy of s with s.Addrs.TakeFirst(n). -// -// Preconditions: As for AddrRangeSeq.TakeFirst. -func (s IOSequence) TakeFirst(n int) IOSequence { - return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts} -} - -// TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n). -// -// Preconditions: As for AddrRangeSeq.TakeFirst64. -func (s IOSequence) TakeFirst64(n int64) IOSequence { - return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts} -} - -// CopyOut invokes CopyOutVec over s.Addrs. -// -// As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated -// to s.NumBytes(), and a nil error will be returned. -// -// Preconditions: As for CopyOutVec. -func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { - return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts) -} - -// CopyIn invokes CopyInVec over s.Addrs. -// -// As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to -// s.NumBytes(), and a nil error will be returned. -// -// Preconditions: As for CopyInVec. -func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { - return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts) -} - -// ZeroOut invokes ZeroOutVec over s.Addrs. -// -// As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated -// to s.NumBytes(), and a nil error will be returned. -// -// Preconditions: As for ZeroOutVec. -func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) { - return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts) -} - -// CopyOutFrom invokes s.CopyOutFrom over s.Addrs. -// -// Preconditions: As for IO.CopyOutFrom. -func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) { - return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts) -} - -// CopyInTo invokes s.CopyInTo over s.Addrs. -// -// Preconditions: As for IO.CopyInTo. -func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) { - return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts) -} - -// Reader returns an io.Reader that reads from s. Reads beyond the end of s -// return io.EOF. The preconditions that apply to s.CopyIn also apply to the -// returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} -} - -// Writer returns an io.Writer that writes to s. Writes beyond the end of s -// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also -// apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} -} - -// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when -// attempting to write beyond the end of the IOSequence. -var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") - -type ioSequenceReadWriter struct { - ctx context.Context - s IOSequence -} - -// Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { - n, err := rw.s.CopyIn(rw.ctx, dst) - rw.s = rw.s.DropFirst(n) - if err == nil && rw.s.NumBytes() == 0 { - err = io.EOF - } - return n, err -} - -// Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { - n, err := rw.s.CopyOut(rw.ctx, src) - rw.s = rw.s.DropFirst(n) - if err == nil && n < len(src) { - err = ErrEndOfIOSequence - } - return n, err -} diff --git a/pkg/sentry/usermem/usermem_arm64.go b/pkg/sentry/usermem/usermem_arm64.go deleted file mode 100644 index fdfc30a66..000000000 --- a/pkg/sentry/usermem/usermem_arm64.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package usermem - -import ( - "encoding/binary" - "syscall" -) - -const ( - // PageSize is the system page size. - // arm64 support 4K/16K/64K page size, - // which can be get by syscall.Getpagesize(). - // Currently, only 4K page size is supported. - PageSize = 1 << PageShift - - // HugePageSize is the system huge page size. - HugePageSize = 1 << HugePageShift - - // PageShift is the binary log of the system page size. - PageShift = 12 - - // HugePageShift is the binary log of the system huge page size. - // Should be calculated by "PageShift + (PageShift - 3)" - // when multiple page size support is ready. - HugePageShift = 21 -) - -var ( - // ByteOrder is the native byte order (little endian). - ByteOrder = binary.LittleEndian -) - -func init() { - // Make sure the page size is 4K on arm64 platform. - if size := syscall.Getpagesize(); size != PageSize { - panic("Only 4K page size is supported on arm64!") - } -} diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go deleted file mode 100644 index 299f64754..000000000 --- a/pkg/sentry/usermem/usermem_test.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "bytes" - "encoding/binary" - "fmt" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// newContext returns a context.Context that we can use in these tests (we -// can't use contexttest because it depends on usermem). -func newContext() context.Context { - return context.Background() -} - -func newBytesIOString(s string) *BytesIO { - return &BytesIO{[]byte(s)} -} - -func TestBytesIOCopyOutSuccess(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInSuccess(t *testing.T) { - b := newBytesIOString("AfooE") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInFailure(t *testing.T) { - b := newBytesIOString("Afo") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutSuccess(t *testing.T) { - b := newBytesIOString("ABCD") - n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) - if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromSuccess(t *testing.T) { - b := newBytesIOString("ABCDEFGH") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromFailure(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToSuccess(t *testing.T) { - b := newBytesIOString("AfoobarH") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToFailure(t *testing.T) { - b := newBytesIOString("Afoob") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -type testStruct struct { - Int8 int8 - Uint8 uint8 - Int16 int16 - Uint16 uint16 - Int32 int32 - Uint32 uint32 - Int64 int64 - Uint64 uint64 -} - -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - -func TestCopyStringInShort(t *testing.T) { - // Tests for string length <= copyStringIncrement. - want := strings.Repeat("A", copyStringIncrement-2) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInLong(t *testing.T) { - // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen - // (requiring multiple calls to IO.CopyIn()). - want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInVeryLong(t *testing.T) { - // Tests for string length > copyStringMaxInitBufLen (requiring buffer - // reallocation). - want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { - want := strings.Repeat("A", copyStringIncrement-1) - got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) - if wantErr := syserror.EFAULT; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyStringInTruncatedByMaxlen(t *testing.T) { - got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) - if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyInt32StringsInVec(t *testing.T) { - for _, test := range []struct { - str string - n int - initial []int32 - final []int32 - }{ - { - str: "100 200", - n: len("100 200"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Fewer values ok - str: "100", - n: len("100"), - initial: []int32{1, 2}, - final: []int32{100, 2}, - }, - { - // Extra values ok - str: "100 200 300", - n: len("100 200 "), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Leading and trailing whitespace ok - str: " 100\t200\n", - n: len(" 100\t200\n"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - src := BytesIOSequence([]byte(test.str)) - dsts := append([]int32(nil), test.initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) - } - if !reflect.DeepEqual(dsts, test.final) { - t.Errorf("dsts: got %v, wanted %v", dsts, test.final) - } - }) - } -} - -func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { - for _, s := range []string{"", "\n", "a123"} { - t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { - src := BytesIOSequence([]byte(s)) - initial := []int32{1, 2} - dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) - } - if !reflect.DeepEqual(dsts, initial) { - t.Errorf("dsts: got %v, wanted %v", dsts, initial) - } - }) - } -} - -func TestIOSequenceCopyOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // CopyOut limited by len(src). - n, err := s.CopyOut(newContext(), []byte("fo")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyOut limited by s.NumBytes(). - n, err = s.CopyOut(newContext(), []byte("obar")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foob"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceCopyIn(t *testing.T) { - s := BytesIOSequence([]byte("foob")) - dst := []byte("ABCDEF") - - // CopyIn limited by len(dst). - n, err := s.CopyIn(newContext(), dst[:2]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCDEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyIn limited by s.Remaining(). - n, err = s.CopyIn(newContext(), dst[2:]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foobEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceZeroOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // ZeroOut limited by toZero. - n, err := s.ZeroOut(newContext(), 2) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // ZeroOut limited by s.NumBytes(). - n, err = s.ZeroOut(newContext(), 4) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceTakeFirst(t *testing.T) { - s := BytesIOSequence([]byte("foobar")) - if got, want := s.NumBytes(), int64(6); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - s = s.TakeFirst(3) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // TakeFirst(n) where n > s.NumBytes() is a no-op. - s = s.TakeFirst(9) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - var dst [3]byte - n, err := s.CopyIn(newContext(), dst[:]) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } - s = s.DropFirst(3) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} diff --git a/pkg/sentry/usermem/usermem_unsafe.go b/pkg/sentry/usermem/usermem_unsafe.go deleted file mode 100644 index 876783e78..000000000 --- a/pkg/sentry/usermem/usermem_unsafe.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "unsafe" -) - -// stringFromImmutableBytes is equivalent to string(bs), except that it never -// copies even if escape analysis can't prove that bs does not escape. This is -// only valid if bs is never mutated after stringFromImmutableBytes returns. -func stringFromImmutableBytes(bs []byte) string { - // Compare strings.Builder.String(). - return *(*string)(unsafe.Pointer(&bs)) -} diff --git a/pkg/sentry/usermem/usermem_x86.go b/pkg/sentry/usermem/usermem_x86.go deleted file mode 100644 index 8059b72d2..000000000 --- a/pkg/sentry/usermem/usermem_x86.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 i386 - -package usermem - -import "encoding/binary" - -const ( - // PageSize is the system page size. - PageSize = 1 << PageShift - - // HugePageSize is the system huge page size. - HugePageSize = 1 << HugePageShift - - // PageShift is the binary log of the system page size. - PageShift = 12 - - // HugePageShift is the binary log of the system huge page size. - HugePageShift = 21 -) - -var ( - // ByteOrder is the native byte order (little endian). - ByteOrder = binary.LittleEndian -) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 51acdc4e9..6b1009328 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -26,14 +26,14 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", "//pkg/waiter", ], ) @@ -48,11 +48,11 @@ go_test( library = ":vfs", deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", + "//pkg/context", + "//pkg/sentry/contexttest", "//pkg/sentry/kernel/auth", - "//pkg/sentry/usermem", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index 705194ebc..d97362b9a 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -15,7 +15,7 @@ package vfs import ( - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" ) // contextID is this package's type for context.Context.Value keys. diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go index 9f9d6e783..3af2aa58d 100644 --- a/pkg/sentry/vfs/device.go +++ b/pkg/sentry/vfs/device.go @@ -17,7 +17,7 @@ package vfs import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 51c95c2d9..225024463 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -18,12 +18,12 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index c00b3c84b..fb9b87fdc 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -19,12 +19,12 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 9ed58512f..1720d325d 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -22,11 +22,11 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // fileDescription is the common fd struct which a filesystem implementation diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index ea78f555b..a06a6caf3 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -18,8 +18,8 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" ) // A Filesystem is a tree of nodes represented by Dentries, which forms part of diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index 023301780..c58b70728 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -18,7 +18,7 @@ import ( "bytes" "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 00177b371..d39528051 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -19,7 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index cf80df90e..b318c681a 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -15,8 +15,8 @@ package vfs import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go index ee5c8b9e2..392c7611e 100644 --- a/pkg/sentry/vfs/testutil.go +++ b/pkg/sentry/vfs/testutil.go @@ -18,8 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 1f6f56293..b2bf48853 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -31,8 +31,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD new file mode 100644 index 000000000..ff8b9e91a --- /dev/null +++ b/pkg/usermem/BUILD @@ -0,0 +1,55 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "addr_range", + out = "addr_range.go", + package = "usermem", + prefix = "Addr", + template = "//pkg/segment:generic_range", + types = { + "T": "Addr", + }, +) + +go_library( + name = "usermem", + srcs = [ + "access_type.go", + "addr.go", + "addr_range.go", + "addr_range_seq_unsafe.go", + "bytes_io.go", + "bytes_io_unsafe.go", + "usermem.go", + "usermem_arm64.go", + "usermem_unsafe.go", + "usermem_x86.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/atomicbitops", + "//pkg/binary", + "//pkg/context", + "//pkg/log", + "//pkg/safemem", + "//pkg/syserror", + ], +) + +go_test( + name = "usermem_test", + size = "small", + srcs = [ + "addr_range_seq_test.go", + "usermem_test.go", + ], + library = ":usermem", + deps = [ + "//pkg/context", + "//pkg/safemem", + "//pkg/syserror", + ], +) diff --git a/pkg/usermem/README.md b/pkg/usermem/README.md new file mode 100644 index 000000000..f6d2137eb --- /dev/null +++ b/pkg/usermem/README.md @@ -0,0 +1,31 @@ +This package defines primitives for sentry access to application memory. + +Major types: + +- The `IO` interface represents a virtual address space and provides I/O + methods on that address space. `IO` is the lowest-level primitive. The + primary implementation of the `IO` interface is `mm.MemoryManager`. + +- `IOSequence` represents a collection of individually-contiguous address + ranges in a `IO` that is operated on sequentially, analogous to Linux's + `struct iov_iter`. + +Major usage patterns: + +- Access to a task's virtual memory, subject to the application's memory + protections and while running on that task's goroutine, from a context that + is at or above the level of the `kernel` package (e.g. most syscall + implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers + defined in `kernel/task_usermem.go`. + +- Access to a task's virtual memory, from a context that is at or above the + level of the `kernel` package, but where any of the above constraints does + not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory + protections); obtain the task's `mm.MemoryManager` by calling + `kernel.Task.MemoryManager`, and call its `IO` methods directly. + +- Access to a task's virtual memory, from a context that is below the level of + the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments + from higher layers, usually in the form of an `IOSequence`. The + `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions + in `kernel/task_usermem.go` are convenience functions for doing so. diff --git a/pkg/usermem/access_type.go b/pkg/usermem/access_type.go new file mode 100644 index 000000000..9c1742a59 --- /dev/null +++ b/pkg/usermem/access_type.go @@ -0,0 +1,128 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "syscall" +) + +// AccessType specifies memory access types. This is used for +// setting mapping permissions, as well as communicating faults. +// +// +stateify savable +type AccessType struct { + // Read is read access. + Read bool + + // Write is write access. + Write bool + + // Execute is executable access. + Execute bool +} + +// String returns a pretty representation of access. This looks like the +// familiar r-x, rw-, etc. and can be relied on as such. +func (a AccessType) String() string { + bits := [3]byte{'-', '-', '-'} + if a.Read { + bits[0] = 'r' + } + if a.Write { + bits[1] = 'w' + } + if a.Execute { + bits[2] = 'x' + } + return string(bits[:]) +} + +// Any returns true iff at least one of Read, Write or Execute is true. +func (a AccessType) Any() bool { + return a.Read || a.Write || a.Execute +} + +// Prot returns the system prot (syscall.PROT_READ, etc.) for this access. +func (a AccessType) Prot() int { + var prot int + if a.Read { + prot |= syscall.PROT_READ + } + if a.Write { + prot |= syscall.PROT_WRITE + } + if a.Execute { + prot |= syscall.PROT_EXEC + } + return prot +} + +// SupersetOf returns true iff the access types in a are a superset of the +// access types in other. +func (a AccessType) SupersetOf(other AccessType) bool { + if !a.Read && other.Read { + return false + } + if !a.Write && other.Write { + return false + } + if !a.Execute && other.Execute { + return false + } + return true +} + +// Intersect returns the access types set in both a and other. +func (a AccessType) Intersect(other AccessType) AccessType { + return AccessType{ + Read: a.Read && other.Read, + Write: a.Write && other.Write, + Execute: a.Execute && other.Execute, + } +} + +// Union returns the access types set in either a or other. +func (a AccessType) Union(other AccessType) AccessType { + return AccessType{ + Read: a.Read || other.Read, + Write: a.Write || other.Write, + Execute: a.Execute || other.Execute, + } +} + +// Effective returns the set of effective access types allowed by a, even if +// some types are not explicitly allowed. +func (a AccessType) Effective() AccessType { + // In Linux, Write and Execute access generally imply Read access. See + // mm/mmap.c:protection_map. + // + // The notable exception is get_user_pages, which only checks against + // the original vma flags. That said, most user memory accesses do not + // use GUP. + if a.Write || a.Execute { + a.Read = true + } + return a +} + +// Convenient access types. +var ( + NoAccess = AccessType{} + Read = AccessType{Read: true} + Write = AccessType{Write: true} + Execute = AccessType{Execute: true} + ReadWrite = AccessType{Read: true, Write: true} + AnyAccess = AccessType{Read: true, Write: true, Execute: true} +) diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go new file mode 100644 index 000000000..e79210804 --- /dev/null +++ b/pkg/usermem/addr.go @@ -0,0 +1,108 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "fmt" +) + +// Addr represents a generic virtual address. +// +// +stateify savable +type Addr uintptr + +// AddLength adds the given length to start and returns the result. ok is true +// iff adding the length did not overflow the range of Addr. +// +// Note: This function is usually used to get the end of an address range +// defined by its start address and length. Since the resulting end is +// exclusive, end == 0 is technically valid, and corresponds to a range that +// extends to the end of the address space, but ok will be false. This isn't +// expected to ever come up in practice. +func (v Addr) AddLength(length uint64) (end Addr, ok bool) { + end = v + Addr(length) + // The second half of the following check is needed in case uintptr is + // smaller than 64 bits. + ok = end >= v && length <= uint64(^Addr(0)) + return +} + +// RoundDown returns the address rounded down to the nearest page boundary. +func (v Addr) RoundDown() Addr { + return v & ^Addr(PageSize-1) +} + +// RoundUp returns the address rounded up to the nearest page boundary. ok is +// true iff rounding up did not wrap around. +func (v Addr) RoundUp() (addr Addr, ok bool) { + addr = Addr(v + PageSize - 1).RoundDown() + ok = addr >= v + return +} + +// MustRoundUp is equivalent to RoundUp, but panics if rounding up wraps +// around. +func (v Addr) MustRoundUp() Addr { + addr, ok := v.RoundUp() + if !ok { + panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v)) + } + return addr +} + +// HugeRoundDown returns the address rounded down to the nearest huge page +// boundary. +func (v Addr) HugeRoundDown() Addr { + return v & ^Addr(HugePageSize-1) +} + +// HugeRoundUp returns the address rounded up to the nearest huge page boundary. +// ok is true iff rounding up did not wrap around. +func (v Addr) HugeRoundUp() (addr Addr, ok bool) { + addr = Addr(v + HugePageSize - 1).HugeRoundDown() + ok = addr >= v + return +} + +// PageOffset returns the offset of v into the current page. +func (v Addr) PageOffset() uint64 { + return uint64(v & Addr(PageSize-1)) +} + +// IsPageAligned returns true if v.PageOffset() == 0. +func (v Addr) IsPageAligned() bool { + return v.PageOffset() == 0 +} + +// AddrRange is a range of Addrs. +// +// type AddrRange + +// ToRange returns [v, v+length). +func (v Addr) ToRange(length uint64) (AddrRange, bool) { + end, ok := v.AddLength(length) + return AddrRange{v, end}, ok +} + +// IsPageAligned returns true if ar.Start.IsPageAligned() and +// ar.End.IsPageAligned(). +func (ar AddrRange) IsPageAligned() bool { + return ar.Start.IsPageAligned() && ar.End.IsPageAligned() +} + +// String implements fmt.Stringer.String. +func (ar AddrRange) String() string { + return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End) +} diff --git a/pkg/usermem/addr_range_seq_test.go b/pkg/usermem/addr_range_seq_test.go new file mode 100644 index 000000000..82f735026 --- /dev/null +++ b/pkg/usermem/addr_range_seq_test.go @@ -0,0 +1,197 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "testing" +) + +var addrRangeSeqTests = []struct { + desc string + ranges []AddrRange +}{ + { + desc: "Empty sequence", + }, + { + desc: "Single empty AddrRange", + ranges: []AddrRange{ + {0x10, 0x10}, + }, + }, + { + desc: "Single non-empty AddrRange of length 1", + ranges: []AddrRange{ + {0x10, 0x11}, + }, + }, + { + desc: "Single non-empty AddrRange of length 2", + ranges: []AddrRange{ + {0x10, 0x12}, + }, + }, + { + desc: "Multiple non-empty AddrRanges", + ranges: []AddrRange{ + {0x10, 0x11}, + {0x20, 0x22}, + }, + }, + { + desc: "Multiple AddrRanges including empty AddrRanges", + ranges: []AddrRange{ + {0x10, 0x10}, + {0x20, 0x20}, + {0x30, 0x33}, + {0x40, 0x44}, + {0x50, 0x50}, + {0x60, 0x60}, + {0x70, 0x77}, + {0x80, 0x88}, + {0x90, 0x90}, + {0xa0, 0xa0}, + }, + }, +} + +func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) { + var wantLen int64 + for _, ar := range wantRanges { + wantLen += int64(ar.Length()) + } + + var i int + for !ars.IsEmpty() { + if gotLen := ars.NumBytes(); gotLen != wantLen { + t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) + } + if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN { + t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN) + } + got := ars.Head() + if i >= len(wantRanges) { + t.Errorf("Iteration %d: %v.Head(): got %s, wanted ", i, ars, got) + } else if want := wantRanges[i]; got != want { + t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) + } + ars = ars.Tail() + wantLen -= int64(got.Length()) + i++ + } + if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { + t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) + } + if gotN := ars.NumRanges(); gotN != 0 { + t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN) + } +} + +func TestAddrRangeSeqTailIteration(t *testing.T) { + for _, test := range addrRangeSeqTests { + t.Run(test.desc, func(t *testing.T) { + testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges) + }) + } +} + +func TestAddrRangeSeqDropFirstEmpty(t *testing.T) { + var ars AddrRangeSeq + if got, want := ars.DropFirst(1), ars; got != want { + t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want) + } +} + +func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) { + // Tests AddrRangeSeq iteration using Head/DropFirst, simulating + // I/O-per-AddrRange. + for _, test := range addrRangeSeqTests { + t.Run(test.desc, func(t *testing.T) { + // Figure out what AddrRanges we expect to see. + var wantLen int64 + var wantRanges []AddrRange + for _, ar := range test.ranges { + wantLen += int64(ar.Length()) + wantRanges = append(wantRanges, ar) + if ar.Length() == 0 { + // We "do" 0 bytes of I/O and then call DropFirst(0), + // advancing to the next AddrRange. + continue + } + // Otherwise we "do" 1 byte of I/O and then call DropFirst(1), + // advancing the AddrRange by 1 byte, or to the next AddrRange + // if this one is exhausted. + for ar.Start++; ar.Length() != 0; ar.Start++ { + wantRanges = append(wantRanges, ar) + } + } + t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen) + + ars := AddrRangeSeqFromSlice(test.ranges) + var i int + for !ars.IsEmpty() { + if gotLen := ars.NumBytes(); gotLen != wantLen { + t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) + } + got := ars.Head() + if i >= len(wantRanges) { + t.Errorf("Iteration %d: %v.Head(): got %s, wanted ", i, ars, got) + } else if want := wantRanges[i]; got != want { + t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) + } + if got.Length() == 0 { + ars = ars.DropFirst(0) + } else { + ars = ars.DropFirst(1) + wantLen-- + } + i++ + } + if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { + t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) + } + }) + } +} + +func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) { + var ars AddrRangeSeq + if got, want := ars.TakeFirst(1), ars; got != want { + t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want) + } +} + +func TestAddrRangeSeqTakeFirst(t *testing.T) { + ranges := []AddrRange{ + {0x10, 0x11}, + {0x20, 0x22}, + {0x30, 0x30}, + {0x40, 0x44}, + {0x50, 0x55}, + {0x60, 0x60}, + {0x70, 0x77}, + } + ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5) + want := []AddrRange{ + {0x10, 0x11}, // +1 byte (total 1 byte), not truncated + {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated + {0x30, 0x30}, // +0 bytes (total 3 bytes), no change + {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated + {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated + {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change) + {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated + } + testAddrRangeSeqEqualityWithTailIteration(t, ars, want) +} diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go new file mode 100644 index 000000000..c09337c15 --- /dev/null +++ b/pkg/usermem/addr_range_seq_unsafe.go @@ -0,0 +1,277 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "bytes" + "fmt" + "reflect" + "unsafe" +) + +// An AddrRangeSeq represents a sequence of AddrRanges. +// +// AddrRangeSeqs are immutable and may be copied by value. The zero value of +// AddrRangeSeq represents an empty sequence. +// +// An AddrRangeSeq may contain AddrRanges with a length of 0. This is necessary +// since zero-length AddrRanges are significant to MM bounds checks. +type AddrRangeSeq struct { + // If length is 0, then the AddrRangeSeq represents no AddrRanges. + // Invariants: data == 0; offset == 0; limit == 0. + // + // If length is 1, then the AddrRangeSeq represents the single + // AddrRange{offset, offset+limit}. Invariants: data == 0. + // + // Otherwise, length >= 2, and the AddrRangeSeq represents the `length` + // AddrRanges in the array of AddrRanges starting at address `data`, + // starting at `offset` bytes into the first AddrRange and limited to the + // following `limit` bytes. (AddrRanges after `limit` are still iterated, + // but are truncated to a length of 0.) Invariants: data != 0; offset <= + // data[0].Length(); limit > 0; offset+limit <= the combined length of all + // AddrRanges in the array. + data unsafe.Pointer + length int + offset Addr + limit Addr +} + +// AddrRangeSeqOf returns an AddrRangeSeq representing the single AddrRange ar. +func AddrRangeSeqOf(ar AddrRange) AddrRangeSeq { + return AddrRangeSeq{ + length: 1, + offset: ar.Start, + limit: ar.Length(), + } +} + +// AddrRangeSeqFromSlice returns an AddrRangeSeq representing all AddrRanges in +// slice. +// +// Whether the returned AddrRangeSeq shares memory with slice is unspecified; +// clients should avoid mutating slices passed to AddrRangeSeqFromSlice. +// +// Preconditions: The combined length of all AddrRanges in slice <= +// math.MaxInt64. +func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq { + var limit int64 + for _, ar := range slice { + len64 := int64(ar.Length()) + if len64 < 0 { + panic(fmt.Sprintf("Length of AddrRange %v overflows int64", ar)) + } + sum := limit + len64 + if sum < limit { + panic(fmt.Sprintf("Total length of AddrRanges %v overflows int64", slice)) + } + limit = sum + } + return addrRangeSeqFromSliceLimited(slice, limit) +} + +// Preconditions: The combined length of all AddrRanges in slice <= limit. +// limit >= 0. If len(slice) != 0, then limit > 0. +func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq { + switch len(slice) { + case 0: + return AddrRangeSeq{} + case 1: + return AddrRangeSeq{ + length: 1, + offset: slice[0].Start, + limit: Addr(limit), + } + default: + return AddrRangeSeq{ + data: unsafe.Pointer(&slice[0]), + length: len(slice), + limit: Addr(limit), + } + } +} + +// IsEmpty returns true if ars.NumRanges() == 0. +// +// Note that since AddrRangeSeq may contain AddrRanges with a length of zero, +// an AddrRange representing 0 bytes (AddrRangeSeq.NumBytes() == 0) is not +// necessarily empty. +func (ars AddrRangeSeq) IsEmpty() bool { + return ars.length == 0 +} + +// NumRanges returns the number of AddrRanges in ars. +func (ars AddrRangeSeq) NumRanges() int { + return ars.length +} + +// NumBytes returns the number of bytes represented by ars. +func (ars AddrRangeSeq) NumBytes() int64 { + return int64(ars.limit) +} + +// Head returns the first AddrRange in ars. +// +// Preconditions: !ars.IsEmpty(). +func (ars AddrRangeSeq) Head() AddrRange { + if ars.length == 0 { + panic("empty AddrRangeSeq") + } + if ars.length == 1 { + return AddrRange{ars.offset, ars.offset + ars.limit} + } + ar := *(*AddrRange)(ars.data) + ar.Start += ars.offset + if ar.Length() > ars.limit { + ar.End = ar.Start + ars.limit + } + return ar +} + +// Tail returns an AddrRangeSeq consisting of all AddrRanges in ars after the +// first. +// +// Preconditions: !ars.IsEmpty(). +func (ars AddrRangeSeq) Tail() AddrRangeSeq { + if ars.length == 0 { + panic("empty AddrRangeSeq") + } + if ars.length == 1 { + return AddrRangeSeq{} + } + return ars.externalTail() +} + +// Preconditions: ars.length >= 2. +func (ars AddrRangeSeq) externalTail() AddrRangeSeq { + headLen := (*AddrRange)(ars.data).Length() - ars.offset + var tailLimit int64 + if ars.limit > headLen { + tailLimit = int64(ars.limit - headLen) + } + var extSlice []AddrRange + extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice)) + extSliceHdr.Data = uintptr(ars.data) + extSliceHdr.Len = ars.length + extSliceHdr.Cap = ars.length + return addrRangeSeqFromSliceLimited(extSlice[1:], tailLimit) +} + +// DropFirst returns an AddrRangeSeq equivalent to ars, but with the first n +// bytes omitted. If n > ars.NumBytes(), DropFirst returns an empty +// AddrRangeSeq. +// +// If !ars.IsEmpty() and ars.Head().Length() == 0, DropFirst will always omit +// at least ars.Head(), even if n == 0. This guarantees that the basic pattern +// of: +// +// for !ars.IsEmpty() { +// n, err = doIOWith(ars.Head()) +// if err != nil { +// return err +// } +// ars = ars.DropFirst(n) +// } +// +// works even in the presence of zero-length AddrRanges. +// +// Preconditions: n >= 0. +func (ars AddrRangeSeq) DropFirst(n int) AddrRangeSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return ars.DropFirst64(int64(n)) +} + +// DropFirst64 is equivalent to DropFirst but takes an int64. +func (ars AddrRangeSeq) DropFirst64(n int64) AddrRangeSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + if Addr(n) > ars.limit { + return AddrRangeSeq{} + } + // Handle initial empty AddrRange. + switch ars.length { + case 0: + return AddrRangeSeq{} + case 1: + if ars.limit == 0 { + return AddrRangeSeq{} + } + default: + if rawHeadLen := (*AddrRange)(ars.data).Length(); ars.offset == rawHeadLen { + ars = ars.externalTail() + } + } + for n != 0 { + // Calling ars.Head() here is surprisingly expensive, so inline getting + // the head's length. + var headLen Addr + if ars.length == 1 { + headLen = ars.limit + } else { + headLen = (*AddrRange)(ars.data).Length() - ars.offset + } + if Addr(n) < headLen { + // Dropping ends partway through the head AddrRange. + ars.offset += Addr(n) + ars.limit -= Addr(n) + return ars + } + n -= int64(headLen) + ars = ars.Tail() + } + return ars +} + +// TakeFirst returns an AddrRangeSeq equivalent to ars, but iterating at most n +// bytes. TakeFirst never removes AddrRanges from ars; AddrRanges beyond the +// first n bytes are reduced to a length of zero, but will still be iterated. +// +// Preconditions: n >= 0. +func (ars AddrRangeSeq) TakeFirst(n int) AddrRangeSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return ars.TakeFirst64(int64(n)) +} + +// TakeFirst64 is equivalent to TakeFirst but takes an int64. +func (ars AddrRangeSeq) TakeFirst64(n int64) AddrRangeSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + if ars.limit > Addr(n) { + ars.limit = Addr(n) + } + return ars +} + +// String implements fmt.Stringer.String. +func (ars AddrRangeSeq) String() string { + // This is deliberately chosen to be the same as fmt's automatic stringer + // for []AddrRange. + var buf bytes.Buffer + buf.WriteByte('[') + var sep string + for !ars.IsEmpty() { + buf.WriteString(sep) + sep = " " + buf.WriteString(ars.Head().String()) + ars = ars.Tail() + } + buf.WriteByte(']') + return buf.String() +} diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go new file mode 100644 index 000000000..e177d30eb --- /dev/null +++ b/pkg/usermem/bytes_io.go @@ -0,0 +1,141 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/syserror" +) + +const maxInt = int(^uint(0) >> 1) + +// BytesIO implements IO using a byte slice. Addresses are interpreted as +// offsets into the slice. Reads and writes beyond the end of the slice return +// EFAULT. +type BytesIO struct { + Bytes []byte +} + +// CopyOut implements IO.CopyOut. +func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) { + rngN, rngErr := b.rangeCheck(addr, len(src)) + if rngN == 0 { + return 0, rngErr + } + return copy(b.Bytes[int(addr):], src[:rngN]), rngErr +} + +// CopyIn implements IO.CopyIn. +func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) { + rngN, rngErr := b.rangeCheck(addr, len(dst)) + if rngN == 0 { + return 0, rngErr + } + return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr +} + +// ZeroOut implements IO.ZeroOut. +func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) { + if toZero > int64(maxInt) { + return 0, syserror.EINVAL + } + rngN, rngErr := b.rangeCheck(addr, int(toZero)) + if rngN == 0 { + return 0, rngErr + } + zeroSlice := b.Bytes[int(addr) : int(addr)+rngN] + for i := range zeroSlice { + zeroSlice[i] = 0 + } + return int64(rngN), rngErr +} + +// CopyOutFrom implements IO.CopyOutFrom. +func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) { + dsts, rngErr := b.blocksFromAddrRanges(ars) + n, err := src.ReadToBlocks(dsts) + if err != nil { + return int64(n), err + } + return int64(n), rngErr +} + +// CopyInTo implements IO.CopyInTo. +func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) { + srcs, rngErr := b.blocksFromAddrRanges(ars) + n, err := dst.WriteFromBlocks(srcs) + if err != nil { + return int64(n), err + } + return int64(n), rngErr +} + +func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { + if length == 0 { + return 0, nil + } + if length < 0 { + return 0, syserror.EINVAL + } + max := Addr(len(b.Bytes)) + if addr >= max { + return 0, syserror.EFAULT + } + end, ok := addr.AddLength(uint64(length)) + if !ok || end > max { + return int(max - addr), syserror.EFAULT + } + return length, nil +} + +func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { + switch ars.NumRanges() { + case 0: + return safemem.BlockSeq{}, nil + case 1: + block, err := b.blockFromAddrRange(ars.Head()) + return safemem.BlockSeqOf(block), err + default: + blocks := make([]safemem.Block, 0, ars.NumRanges()) + for !ars.IsEmpty() { + block, err := b.blockFromAddrRange(ars.Head()) + if block.Len() != 0 { + blocks = append(blocks, block) + } + if err != nil { + return safemem.BlockSeqFromSlice(blocks), err + } + ars = ars.Tail() + } + return safemem.BlockSeqFromSlice(blocks), nil + } +} + +func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { + n, err := b.rangeCheck(ar.Start, int(ar.Length())) + if n == 0 { + return safemem.Block{}, err + } + return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err +} + +// BytesIOSequence returns an IOSequence representing the given byte slice. +func BytesIOSequence(buf []byte) IOSequence { + return IOSequence{ + IO: &BytesIO{buf}, + Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}), + } +} diff --git a/pkg/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go new file mode 100644 index 000000000..20de5037d --- /dev/null +++ b/pkg/usermem/bytes_io_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "sync/atomic" + "unsafe" + + "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/context" +) + +// SwapUint32 implements IO.SwapUint32. +func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) { + if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { + return 0, rngErr + } + return atomic.SwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), new), nil +} + +// CompareAndSwapUint32 implements IO.CompareAndSwapUint32. +func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) { + if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil { + return 0, rngErr + } + return atomicbitops.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)])), old, new), nil +} + +// LoadUint32 implements IO.LoadUint32. +func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) { + if _, err := b.rangeCheck(addr, 4); err != nil { + return 0, err + } + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&b.Bytes[int(addr)]))), nil +} diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go new file mode 100644 index 000000000..71fd4e155 --- /dev/null +++ b/pkg/usermem/usermem.go @@ -0,0 +1,597 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package usermem governs access to user memory. +package usermem + +import ( + "bytes" + "errors" + "io" + "strconv" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// IO provides access to the contents of a virtual memory space. +// +// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any +// meaningful data. +type IO interface { + // CopyOut copies len(src) bytes from src to the memory mapped at addr. It + // returns the number of bytes copied. If the number of bytes copied is < + // len(src), it returns a non-nil error explaining why. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. + // + // Postconditions: CopyOut does not retain src. + CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) + + // CopyIn copies len(dst) bytes from the memory mapped at addr to dst. + // It returns the number of bytes copied. If the number of bytes copied is + // < len(dst), it returns a non-nil error explaining why. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. + // + // Postconditions: CopyIn does not retain dst. + CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) + + // ZeroOut sets toZero bytes to 0, starting at addr. It returns the number + // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a + // non-nil error explaining why. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. toZero >= 0. + ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) + + // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at + // ars. It returns the number of bytes copied, which may be less than the + // number of bytes read from src if copying fails. CopyOutFrom may return a + // partial copy without an error iff src.ReadToBlocks returns a partial + // read without an error. + // + // CopyOutFrom calls src.ReadToBlocks at most once. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. src.ReadToBlocks must not block + // on mm.MemoryManager.activeMu or any preceding locks in the lock order. + CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) + + // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to + // dst. It returns the number of bytes copied. CopyInTo may return a + // partial copy without an error iff dst.WriteFromBlocks returns a partial + // write without an error. + // + // CopyInTo calls dst.WriteFromBlocks at most once. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. dst.WriteFromBlocks must not + // block on mm.MemoryManager.activeMu or any preceding locks in the lock + // order. + CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) + + // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst + // at most once, which is unnecessary in most cases, forces implementations + // to gather safemem.Blocks into a single slice to pass to src/dst. Add + // CopyOutFromIter/CopyInToIter, which relaxes this restriction, to avoid + // this allocation. + + // SwapUint32 atomically sets the uint32 value at addr to new and + // returns the previous value. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. addr must be aligned to a 4-byte + // boundary. + SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) + + // CompareAndSwapUint32 atomically compares the uint32 value at addr to + // old; if they are equal, the value in memory is replaced by new. In + // either case, the previous value stored in memory is returned. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. addr must be aligned to a 4-byte + // boundary. + CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) + + // LoadUint32 atomically loads the uint32 value at addr and returns it. + // + // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or + // any following locks in the lock order. addr must be aligned to a 4-byte + // boundary. + LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) +} + +// IOOpts contains options applicable to all IO methods. +type IOOpts struct { + // If IgnorePermissions is true, application-defined memory protections set + // by mmap(2) or mprotect(2) will be ignored. (Memory protections required + // by the target of the mapping are never ignored.) + IgnorePermissions bool + + // If AddressSpaceActive is true, the IO implementation may assume that it + // has an active AddressSpace and can therefore use AddressSpace copying + // without performing activation. See mm/io.go for details. + AddressSpaceActive bool +} + +// IOReadWriter is an io.ReadWriter that reads from / writes to addresses +// starting at addr in IO. The preconditions that apply to IO.CopyIn and +// IO.CopyOut also apply to IOReadWriter.Read and IOReadWriter.Write +// respectively. +type IOReadWriter struct { + Ctx context.Context + IO IO + Addr Addr + Opts IOOpts +} + +// Read implements io.Reader.Read. +// +// Note that an address space does not have an "end of file", so Read can only +// return io.EOF if IO.CopyIn returns io.EOF. Attempts to read unmapped or +// unreadable memory, or beyond the end of the address space, should return +// EFAULT. +func (rw *IOReadWriter) Read(dst []byte) (int, error) { + n, err := rw.IO.CopyIn(rw.Ctx, rw.Addr, dst, rw.Opts) + end, ok := rw.Addr.AddLength(uint64(n)) + if ok { + rw.Addr = end + } else { + // Disallow wraparound. + rw.Addr = ^Addr(0) + if err != nil { + err = syserror.EFAULT + } + } + return n, err +} + +// Writer implements io.Writer.Write. +func (rw *IOReadWriter) Write(src []byte) (int, error) { + n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts) + end, ok := rw.Addr.AddLength(uint64(n)) + if ok { + rw.Addr = end + } else { + // Disallow wraparound. + rw.Addr = ^Addr(0) + if err != nil { + err = syserror.EFAULT + } + } + return n, err +} + +// CopyObjectOut copies a fixed-size value or slice of fixed-size values from +// src to the memory mapped at addr in uio. It returns the number of bytes +// copied. +// +// CopyObjectOut must use reflection to encode src; performance-sensitive +// clients should do encoding manually and use uio.CopyOut directly. +// +// Preconditions: As for IO.CopyOut. +func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { + w := &IOReadWriter{ + Ctx: ctx, + IO: uio, + Addr: addr, + Opts: opts, + } + // Allocate a byte slice the size of the object being marshaled. This + // adds an extra reflection call, but avoids needing to grow the slice + // during encoding, which can result in many heap-allocated slices. + b := make([]byte, 0, binary.Size(src)) + return w.Write(binary.Marshal(b, ByteOrder, src)) +} + +// CopyObjectIn copies a fixed-size value or slice of fixed-size values from +// the memory mapped at addr in uio to dst. It returns the number of bytes +// copied. +// +// CopyObjectIn must use reflection to decode dst; performance-sensitive +// clients should use uio.CopyIn directly and do decoding manually. +// +// Preconditions: As for IO.CopyIn. +func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { + r := &IOReadWriter{ + Ctx: ctx, + IO: uio, + Addr: addr, + Opts: opts, + } + buf := make([]byte, binary.Size(dst)) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, err + } + binary.Unmarshal(buf, ByteOrder, dst) + return int(r.Addr - addr), nil +} + +// CopyStringIn tuning parameters, defined outside that function for tests. +const ( + copyStringIncrement = 64 + copyStringMaxInitBufLen = 256 +) + +// CopyStringIn copies a NUL-terminated string of unknown length from the +// memory mapped at addr in uio and returns it as a string (not including the +// trailing NUL). If the length of the string, including the terminating NUL, +// would exceed maxlen, CopyStringIn returns the string truncated to maxlen and +// ENAMETOOLONG. +// +// Preconditions: As for IO.CopyFromUser. maxlen >= 0. +func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { + initLen := maxlen + if initLen > copyStringMaxInitBufLen { + initLen = copyStringMaxInitBufLen + } + buf := make([]byte, initLen) + var done int + for done < maxlen { + // Read up to copyStringIncrement bytes at a time. + readlen := copyStringIncrement + if readlen > maxlen-done { + readlen = maxlen - done + } + end, ok := addr.AddLength(uint64(readlen)) + if !ok { + return stringFromImmutableBytes(buf[:done]), syserror.EFAULT + } + // Shorten the read to avoid crossing page boundaries, since faulting + // in a page unnecessarily is expensive. This also ensures that partial + // copies up to the end of application-mappable memory succeed. + if addr.RoundDown() != end.RoundDown() { + end = end.RoundDown() + readlen = int(end - addr) + } + // Ensure that our buffer is large enough to accommodate the read. + if done+readlen > len(buf) { + newBufLen := len(buf) * 2 + if newBufLen > maxlen { + newBufLen = maxlen + } + buf = append(buf, make([]byte, newBufLen-len(buf))...) + } + n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) + // Look for the terminating zero byte, which may have occurred before + // hitting err. + if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { + return stringFromImmutableBytes(buf[:done+i]), nil + } + + done += n + if err != nil { + return stringFromImmutableBytes(buf[:done]), err + } + addr = end + } + return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG +} + +// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The +// maximum number of bytes copied is ars.NumBytes() or len(src), whichever is +// less. CopyOutVec returns the number of bytes copied; if this is less than +// the maximum, it returns a non-nil error explaining why. +// +// Preconditions: As for IO.CopyOut. +func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) { + var done int + for !ars.IsEmpty() && done < len(src) { + ar := ars.Head() + cplen := len(src) - done + if Addr(cplen) >= ar.Length() { + cplen = int(ar.Length()) + } + n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts) + done += n + if err != nil { + return done, err + } + ars = ars.DropFirst(n) + } + return done, nil +} + +// CopyInVec copies bytes from the memory mapped at ars in uio to dst. The +// maximum number of bytes copied is ars.NumBytes() or len(dst), whichever is +// less. CopyInVec returns the number of bytes copied; if this is less than the +// maximum, it returns a non-nil error explaining why. +// +// Preconditions: As for IO.CopyIn. +func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { + var done int + for !ars.IsEmpty() && done < len(dst) { + ar := ars.Head() + cplen := len(dst) - done + if Addr(cplen) >= ar.Length() { + cplen = int(ar.Length()) + } + n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts) + done += n + if err != nil { + return done, err + } + ars = ars.DropFirst(n) + } + return done, nil +} + +// ZeroOutVec writes zeroes to the memory mapped at ars in uio. The maximum +// number of bytes written is ars.NumBytes() or toZero, whichever is less. +// ZeroOutVec returns the number of bytes written; if this is less than the +// maximum, it returns a non-nil error explaining why. +// +// Preconditions: As for IO.ZeroOut. +func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { + var done int64 + for !ars.IsEmpty() && done < toZero { + ar := ars.Head() + cplen := toZero - done + if Addr(cplen) >= ar.Length() { + cplen = int64(ar.Length()) + } + n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts) + done += n + if err != nil { + return done, err + } + ars = ars.DropFirst64(n) + } + return done, nil +} + +func isASCIIWhitespace(b byte) bool { + // Compare Linux include/linux/ctype.h, lib/ctype.c. + // 9 => horizontal tab '\t' + // 10 => line feed '\n' + // 11 => vertical tab '\v' + // 12 => form feed '\c' + // 13 => carriage return '\r' + return b == ' ' || (b >= 9 && b <= 13) +} + +// CopyInt32StringsInVec copies up to len(dsts) whitespace-separated decimal +// strings from the memory mapped at ars in uio and converts them to int32 +// values in dsts. It returns the number of bytes read. +// +// CopyInt32StringsInVec shares the following properties with Linux's +// kernel/sysctl.c:proc_dointvec(write=1): +// +// - If any read value overflows the range of int32, or any invalid characters +// are encountered during the read, CopyInt32StringsInVec returns EINVAL. +// +// - If, upon reaching the end of ars, fewer than len(dsts) values have been +// read, CopyInt32StringsInVec returns no error if at least 1 value was read +// and EINVAL otherwise. +// +// - Trailing whitespace after the last successfully read value is counted in +// the number of bytes read. +// +// Unlike proc_dointvec(): +// +// - CopyInt32StringsInVec does not implicitly limit ars.NumBytes() to +// PageSize-1; callers that require this must do so explicitly. +// +// - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0. +// +// Preconditions: As for CopyInVec. +func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { + if len(dsts) == 0 { + return 0, nil + } + + buf := make([]byte, ars.NumBytes()) + n, cperr := CopyInVec(ctx, uio, ars, buf, opts) + buf = buf[:n] + + var i, j int + for ; j < len(dsts); j++ { + // Skip leading whitespace. + for i < len(buf) && isASCIIWhitespace(buf[i]) { + i++ + } + if i == len(buf) { + break + } + + // Find the end of the value to be parsed (next whitespace or end of string). + nextI := i + 1 + for nextI < len(buf) && !isASCIIWhitespace(buf[nextI]) { + nextI++ + } + + // Parse a single value. + val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32) + if err != nil { + return int64(i), syserror.EINVAL + } + dsts[j] = int32(val) + + i = nextI + } + + // Skip trailing whitespace. + for i < len(buf) && isASCIIWhitespace(buf[i]) { + i++ + } + + if cperr != nil { + return int64(i), cperr + } + if j == 0 { + return int64(i), syserror.EINVAL + } + return int64(i), nil +} + +// CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at +// most one int32. +func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) { + dsts := [1]int32{*dst} + n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts) + *dst = dsts[0] + return n, err +} + +// IOSequence holds arguments to IO methods. +type IOSequence struct { + IO IO + Addrs AddrRangeSeq + Opts IOOpts +} + +// NumBytes returns s.Addrs.NumBytes(). +// +// Note that NumBytes() may return 0 even if !s.Addrs.IsEmpty(), since +// s.Addrs may contain a non-zero number of zero-length AddrRanges. +// Many clients of +// IOSequence currently do something like: +// +// if ioseq.NumBytes() == 0 { +// return 0, nil +// } +// if f.availableBytes == 0 { +// return 0, syserror.ErrWouldBlock +// } +// return ioseq.CopyOutFrom(..., reader) +// +// In such cases, using s.Addrs.IsEmpty() will cause them to have the wrong +// behavior for zero-length I/O. However, using s.NumBytes() == 0 instead means +// that we will return success for zero-length I/O in cases where Linux would +// return EFAULT due to a failed access_ok() check, so in the long term we +// should move checks for ErrWouldBlock etc. into the body of +// reader.ReadToBlocks and use s.Addrs.IsEmpty() instead. +func (s IOSequence) NumBytes() int64 { + return s.Addrs.NumBytes() +} + +// DropFirst returns a copy of s with s.Addrs.DropFirst(n). +// +// Preconditions: As for AddrRangeSeq.DropFirst. +func (s IOSequence) DropFirst(n int) IOSequence { + return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts} +} + +// DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n). +// +// Preconditions: As for AddrRangeSeq.DropFirst64. +func (s IOSequence) DropFirst64(n int64) IOSequence { + return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts} +} + +// TakeFirst returns a copy of s with s.Addrs.TakeFirst(n). +// +// Preconditions: As for AddrRangeSeq.TakeFirst. +func (s IOSequence) TakeFirst(n int) IOSequence { + return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts} +} + +// TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n). +// +// Preconditions: As for AddrRangeSeq.TakeFirst64. +func (s IOSequence) TakeFirst64(n int64) IOSequence { + return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts} +} + +// CopyOut invokes CopyOutVec over s.Addrs. +// +// As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated +// to s.NumBytes(), and a nil error will be returned. +// +// Preconditions: As for CopyOutVec. +func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { + return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts) +} + +// CopyIn invokes CopyInVec over s.Addrs. +// +// As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to +// s.NumBytes(), and a nil error will be returned. +// +// Preconditions: As for CopyInVec. +func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { + return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts) +} + +// ZeroOut invokes ZeroOutVec over s.Addrs. +// +// As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated +// to s.NumBytes(), and a nil error will be returned. +// +// Preconditions: As for ZeroOutVec. +func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) { + return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts) +} + +// CopyOutFrom invokes s.CopyOutFrom over s.Addrs. +// +// Preconditions: As for IO.CopyOutFrom. +func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) { + return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts) +} + +// CopyInTo invokes s.CopyInTo over s.Addrs. +// +// Preconditions: As for IO.CopyInTo. +func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) { + return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts) +} + +// Reader returns an io.Reader that reads from s. Reads beyond the end of s +// return io.EOF. The preconditions that apply to s.CopyIn also apply to the +// returned io.Reader.Read. +func (s IOSequence) Reader(ctx context.Context) io.Reader { + return &ioSequenceReadWriter{ctx, s} +} + +// Writer returns an io.Writer that writes to s. Writes beyond the end of s +// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also +// apply to the returned io.Writer.Write. +func (s IOSequence) Writer(ctx context.Context) io.Writer { + return &ioSequenceReadWriter{ctx, s} +} + +// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when +// attempting to write beyond the end of the IOSequence. +var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") + +type ioSequenceReadWriter struct { + ctx context.Context + s IOSequence +} + +// Read implements io.Reader.Read. +func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { + n, err := rw.s.CopyIn(rw.ctx, dst) + rw.s = rw.s.DropFirst(n) + if err == nil && rw.s.NumBytes() == 0 { + err = io.EOF + } + return n, err +} + +// Write implements io.Writer.Write. +func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { + n, err := rw.s.CopyOut(rw.ctx, src) + rw.s = rw.s.DropFirst(n) + if err == nil && n < len(src) { + err = ErrEndOfIOSequence + } + return n, err +} diff --git a/pkg/usermem/usermem_arm64.go b/pkg/usermem/usermem_arm64.go new file mode 100644 index 000000000..fdfc30a66 --- /dev/null +++ b/pkg/usermem/usermem_arm64.go @@ -0,0 +1,53 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package usermem + +import ( + "encoding/binary" + "syscall" +) + +const ( + // PageSize is the system page size. + // arm64 support 4K/16K/64K page size, + // which can be get by syscall.Getpagesize(). + // Currently, only 4K page size is supported. + PageSize = 1 << PageShift + + // HugePageSize is the system huge page size. + HugePageSize = 1 << HugePageShift + + // PageShift is the binary log of the system page size. + PageShift = 12 + + // HugePageShift is the binary log of the system huge page size. + // Should be calculated by "PageShift + (PageShift - 3)" + // when multiple page size support is ready. + HugePageShift = 21 +) + +var ( + // ByteOrder is the native byte order (little endian). + ByteOrder = binary.LittleEndian +) + +func init() { + // Make sure the page size is 4K on arm64 platform. + if size := syscall.Getpagesize(); size != PageSize { + panic("Only 4K page size is supported on arm64!") + } +} diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go new file mode 100644 index 000000000..bf3c5df2b --- /dev/null +++ b/pkg/usermem/usermem_test.go @@ -0,0 +1,424 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// newContext returns a context.Context that we can use in these tests (we +// can't use contexttest because it depends on usermem). +func newContext() context.Context { + return context.Background() +} + +func newBytesIOString(s string) *BytesIO { + return &BytesIO{[]byte(s)} +} + +func TestBytesIOCopyOutSuccess(t *testing.T) { + b := newBytesIOString("ABCDE") + n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) + if wantN := 3; n != wantN || err != nil { + t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyOutFailure(t *testing.T) { + b := newBytesIOString("ABC") + n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) + if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { + t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyInSuccess(t *testing.T) { + b := newBytesIOString("AfooE") + var dst [3]byte + n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) + if wantN := 3; n != wantN || err != nil { + t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyInFailure(t *testing.T) { + b := newBytesIOString("Afo") + var dst [3]byte + n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) + if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { + t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +func TestBytesIOZeroOutSuccess(t *testing.T) { + b := newBytesIOString("ABCD") + n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) + if wantN := int64(2); n != wantN || err != nil { + t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOZeroOutFailure(t *testing.T) { + b := newBytesIOString("ABC") + n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) + if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr { + t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyOutFromSuccess(t *testing.T) { + b := newBytesIOString("ABCDEFGH") + n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + {Start: 4, End: 7}, + {Start: 1, End: 4}, + }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) + if wantN := int64(6); n != wantN || err != nil { + t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyOutFromFailure(t *testing.T) { + b := newBytesIOString("ABCDE") + n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + {Start: 1, End: 4}, + {Start: 4, End: 7}, + }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) + if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { + t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { + t.Errorf("Bytes: got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyInToSuccess(t *testing.T) { + b := newBytesIOString("AfoobarH") + var dst bytes.Buffer + n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + {Start: 4, End: 7}, + {Start: 1, End: 4}, + }), safemem.FromIOWriter{&dst}, IOOpts{}) + if wantN := int64(6); n != wantN || err != nil { + t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { + t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) + } +} + +func TestBytesIOCopyInToFailure(t *testing.T) { + b := newBytesIOString("Afoob") + var dst bytes.Buffer + n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ + {Start: 1, End: 4}, + {Start: 4, End: 7}, + }), safemem.FromIOWriter{&dst}, IOOpts{}) + if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { + t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { + t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) + } +} + +type testStruct struct { + Int8 int8 + Uint8 uint8 + Int16 int16 + Uint16 uint16 + Int32 int32 + Uint32 uint32 + Int64 int64 + Uint64 uint64 +} + +func TestCopyObject(t *testing.T) { + wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} + wantN := binary.Size(wantObj) + b := &BytesIO{make([]byte, wantN)} + ctx := newContext() + if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { + t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + var gotObj testStruct + if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { + t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if gotObj != wantObj { + t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) + } +} + +func TestCopyStringInShort(t *testing.T) { + // Tests for string length <= copyStringIncrement. + want := strings.Repeat("A", copyStringIncrement-2) + mem := want + "\x00" + if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) + } +} + +func TestCopyStringInLong(t *testing.T) { + // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen + // (requiring multiple calls to IO.CopyIn()). + want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) + mem := want + "\x00" + if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) + } +} + +func TestCopyStringInVeryLong(t *testing.T) { + // Tests for string length > copyStringMaxInitBufLen (requiring buffer + // reallocation). + want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) + mem := want + "\x00" + if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) + } +} + +func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { + want := strings.Repeat("A", copyStringIncrement-1) + got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) + if wantErr := syserror.EFAULT; got != want || err != wantErr { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) + } +} + +func TestCopyStringInTruncatedByMaxlen(t *testing.T) { + got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) + if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) + } +} + +func TestCopyInt32StringsInVec(t *testing.T) { + for _, test := range []struct { + str string + n int + initial []int32 + final []int32 + }{ + { + str: "100 200", + n: len("100 200"), + initial: []int32{1, 2}, + final: []int32{100, 200}, + }, + { + // Fewer values ok + str: "100", + n: len("100"), + initial: []int32{1, 2}, + final: []int32{100, 2}, + }, + { + // Extra values ok + str: "100 200 300", + n: len("100 200 "), + initial: []int32{1, 2}, + final: []int32{100, 200}, + }, + { + // Leading and trailing whitespace ok + str: " 100\t200\n", + n: len(" 100\t200\n"), + initial: []int32{1, 2}, + final: []int32{100, 200}, + }, + } { + t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { + src := BytesIOSequence([]byte(test.str)) + dsts := append([]int32(nil), test.initial...) + if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { + t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) + } + if !reflect.DeepEqual(dsts, test.final) { + t.Errorf("dsts: got %v, wanted %v", dsts, test.final) + } + }) + } +} + +func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { + for _, s := range []string{"", "\n", "a123"} { + t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { + src := BytesIOSequence([]byte(s)) + initial := []int32{1, 2} + dsts := append([]int32(nil), initial...) + if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { + t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) + } + if !reflect.DeepEqual(dsts, initial) { + t.Errorf("dsts: got %v, wanted %v", dsts, initial) + } + }) + } +} + +func TestIOSequenceCopyOut(t *testing.T) { + buf := []byte("ABCD") + s := BytesIOSequence(buf) + + // CopyOut limited by len(src). + n, err := s.CopyOut(newContext(), []byte("fo")) + if wantN := 2; n != wantN || err != nil { + t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("foCD"); !bytes.Equal(buf, want) { + t.Errorf("buf: got %q, wanted %q", buf, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(2); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + // CopyOut limited by s.NumBytes(). + n, err = s.CopyOut(newContext(), []byte("obar")) + if wantN := 2; n != wantN || err != nil { + t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("foob"); !bytes.Equal(buf, want) { + t.Errorf("buf: got %q, wanted %q", buf, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(0); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } +} + +func TestIOSequenceCopyIn(t *testing.T) { + s := BytesIOSequence([]byte("foob")) + dst := []byte("ABCDEF") + + // CopyIn limited by len(dst). + n, err := s.CopyIn(newContext(), dst[:2]) + if wantN := 2; n != wantN || err != nil { + t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("foCDEF"); !bytes.Equal(dst, want) { + t.Errorf("dst: got %q, wanted %q", dst, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(2); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + // CopyIn limited by s.Remaining(). + n, err = s.CopyIn(newContext(), dst[2:]) + if wantN := 2; n != wantN || err != nil { + t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("foobEF"); !bytes.Equal(dst, want) { + t.Errorf("dst: got %q, wanted %q", dst, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(0); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } +} + +func TestIOSequenceZeroOut(t *testing.T) { + buf := []byte("ABCD") + s := BytesIOSequence(buf) + + // ZeroOut limited by toZero. + n, err := s.ZeroOut(newContext(), 2) + if wantN := int64(2); n != wantN || err != nil { + t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { + t.Errorf("buf: got %q, wanted %q", buf, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(2); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + // ZeroOut limited by s.NumBytes(). + n, err = s.ZeroOut(newContext(), 4) + if wantN := int64(2); n != wantN || err != nil { + t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { + t.Errorf("buf: got %q, wanted %q", buf, want) + } + s = s.DropFirst(2) + if got, want := s.NumBytes(), int64(0); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } +} + +func TestIOSequenceTakeFirst(t *testing.T) { + s := BytesIOSequence([]byte("foobar")) + if got, want := s.NumBytes(), int64(6); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + s = s.TakeFirst(3) + if got, want := s.NumBytes(), int64(3); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + // TakeFirst(n) where n > s.NumBytes() is a no-op. + s = s.TakeFirst(9) + if got, want := s.NumBytes(), int64(3); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } + + var dst [3]byte + n, err := s.CopyIn(newContext(), dst[:]) + if wantN := 3; n != wantN || err != nil { + t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } + s = s.DropFirst(3) + if got, want := s.NumBytes(), int64(0); got != want { + t.Errorf("NumBytes: got %v, wanted %v", got, want) + } +} diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/usermem/usermem_unsafe.go new file mode 100644 index 000000000..876783e78 --- /dev/null +++ b/pkg/usermem/usermem_unsafe.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "unsafe" +) + +// stringFromImmutableBytes is equivalent to string(bs), except that it never +// copies even if escape analysis can't prove that bs does not escape. This is +// only valid if bs is never mutated after stringFromImmutableBytes returns. +func stringFromImmutableBytes(bs []byte) string { + // Compare strings.Builder.String(). + return *(*string)(unsafe.Pointer(&bs)) +} diff --git a/pkg/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go new file mode 100644 index 000000000..8059b72d2 --- /dev/null +++ b/pkg/usermem/usermem_x86.go @@ -0,0 +1,38 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 i386 + +package usermem + +import "encoding/binary" + +const ( + // PageSize is the system page size. + PageSize = 1 << PageShift + + // HugePageSize is the system huge page size. + HugePageSize = 1 << HugePageShift + + // PageShift is the binary log of the system page size. + PageShift = 12 + + // HugePageShift is the binary log of the system huge page size. + HugePageShift = 21 +) + +var ( + // ByteOrder is the native byte order (little endian). + ByteOrder = binary.LittleEndian +) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index f3ebc0231..a96c80261 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -30,6 +30,7 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", + "//pkg/context", "//pkg/control/server", "//pkg/cpuid", "//pkg/eventchannel", @@ -39,7 +40,6 @@ go_library( "//pkg/refs", "//pkg/sentry/arch", "//pkg/sentry/arch:registers_go_proto", - "//pkg/sentry/context", "//pkg/sentry/control", "//pkg/sentry/fs", "//pkg/sentry/fs/dev", @@ -71,7 +71,6 @@ go_library( "//pkg/sentry/time", "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", "//pkg/sentry/usage", - "//pkg/sentry/usermem", "//pkg/sentry/watchdog", "//pkg/sync", "//pkg/syserror", @@ -88,6 +87,7 @@ go_library( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/urpc", + "//pkg/usermem", "//runsc/boot/filter", "//runsc/boot/platforms", "//runsc/specutils", @@ -111,7 +111,7 @@ go_test( "//pkg/control/server", "//pkg/log", "//pkg/p9", - "//pkg/sentry/context/contexttest", + "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sync", diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go index e5de1f3d7..417d2d5fb 100644 --- a/runsc/boot/fds.go +++ b/runsc/boot/fds.go @@ -17,7 +17,7 @@ package boot import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/kernel" diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 421ccd255..0f62842ea 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -32,8 +32,8 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/gofer" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index bec0dc292..44aa63196 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -27,7 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/control/server" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" diff --git a/runsc/boot/user.go b/runsc/boot/user.go index 56cc12ee0..f0aa52135 100644 --- a/runsc/boot/user.go +++ b/runsc/boot/user.go @@ -22,10 +22,10 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) type fileReader struct { diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go index 9aee2ad07..fb4e13dfb 100644 --- a/runsc/boot/user_test.go +++ b/runsc/boot/user_test.go @@ -23,7 +23,7 @@ import ( "testing" specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index 2918ceffe..d79786a68 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -54,8 +54,8 @@ go_marshal = rule( # marshal_deps are the dependencies requied by generated code. marshal_deps = [ "//tools/go_marshal/marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", + "//pkg/safecopy", + "//pkg/usermem", ] # marshal_test_deps are required by test targets. diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 8392f3f6d..af90bdecb 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -27,8 +27,8 @@ import ( const ( marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" - usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem" - safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" + safecopyImport = "gvisor.dev/gvisor/pkg/safecopy" + usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) // List of identifiers we use in generated code, that may conflict a diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 38ba49fed..e345e3a8e 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -15,7 +15,7 @@ go_test( deps = [ ":test", "//pkg/binary", - "//pkg/sentry/usermem", + "//pkg/usermem", "//tools/go_marshal/analysis", ], ) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go index e70db06d8..e12403741 100644 --- a/tools/go_marshal/test/benchmark_test.go +++ b/tools/go_marshal/test/benchmark_test.go @@ -22,7 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" test "gvisor.dev/gvisor/tools/go_marshal/test" ) -- cgit v1.2.3 From 51b783505b1ec164b02b48a0fd234509fba01a73 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 29 Jan 2020 15:41:51 -0800 Subject: Add support for TCP_DEFER_ACCEPT. PiperOrigin-RevId: 292233574 --- pkg/sentry/socket/netstack/netstack.go | 22 ++++ pkg/tcpip/tcpip.go | 6 ++ pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 25 ++--- pkg/tcpip/transport/tcp/connect.go | 53 +++++++++- pkg/tcpip/transport/tcp/endpoint.go | 26 ++++- pkg/tcpip/transport/tcp/forwarder.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 126 ++++++++++++++++++++++ test/syscalls/linux/socket_inet_loopback.cc | 158 ++++++++++++++++++++++++++++ test/syscalls/linux/tcp_socket.cc | 53 ++++++++++ 10 files changed, 451 insertions(+), 23 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8619cc506..049d04bf2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1260,6 +1260,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_DEFER_ACCEPT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPDeferAcceptOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1713,6 +1725,16 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_DEFER_ACCEPT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 59c9b3fb0..0fa141d58 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a +// accept to return a completed connection only when there is data to be +// read. This usually means the listening socket will drop the final ACK +// for a handshake till the specified timeout until a segment with data arrives. +type TCPDeferAcceptOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4acd9fb9a..7b4a87a2d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -57,6 +57,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d469758eb..6101f2945 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6only n.ID = s.id n.boundNICID = s.route.NICID() @@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // createEndpoint creates a new endpoint in connected state and then performs // the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) if err != nil { return nil, err } // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { @@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - - h.resetToSynRcvd(isn, irs, opts) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer e.decSynRcvdCount() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(n) @@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // space available in the backlog. // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4e3c5419c..9ff7ac261 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -86,6 +86,19 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int + + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time + + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration + + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. + acked bool } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { @@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.deferAccept = deferAccept h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() @@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } h.ep.mu.Unlock() - return nil } @@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) @@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error { switch index, _ := s.Fetch(true); index { case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). + if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 13718ff55..8d52414b7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -498,6 +498,13 @@ type endpoint struct { // without any data being acked. userTimeout time.Duration + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPDeferAcceptOption: + e.mu.Lock() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.TCPDeferAcceptOption: + e.mu.Lock() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -2149,9 +2171,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { +func (e *endpoint) startAcceptedLoop() { e.mu.Lock() - e.waiterQueue = waiterQueue e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2177,7 +2198,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } - return n, n.waiterQueue, nil } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 7eb613be5..c9ee5bf06 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue) if err != nil { return nil, err } // Start the protocol goroutine. - ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index df2fb1071..a12336d47 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6787,3 +6787,129 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { ), ) } + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 2f9821555..3bf7081b9 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -828,6 +828,164 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { EXPECT_EQ(get, kUserTimeout); } +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now write some data to the socket. + int data = 0; + ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + // This should now cause the connection to complete and be delivered to the + // accept socket. + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Verify that the accepted socket returns the data written. + int get = -1; + ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0), + SyscallSucceedsWithValue(sizeof(get))); + + EXPECT_EQ(get, data); +} + +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R once issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT + // timeout is hit. + absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1)); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout + // is hit a SYN-ACK should be retransmitted by the listener as a last ditch + // attempt to complete the connection with or without data. + absl::SleepFor(absl::Seconds(2)); + + // Verify that we have a connection that can be accepted even though no + // data was written. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 33a5ac66c..525ccbd88 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1286,6 +1286,59 @@ TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { EXPECT_EQ(get, kTCPUserTimeout); } +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // -ve TCP_DEFER_ACCEPT is same as setting it to zero. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // kTCPDeferAccept is in seconds. + // NOTE: linux translates seconds to # of retries and back from + // #of retries to seconds. Which means only certain values + // translate back exactly. That's why we use 3 here, a value of + // 5 will result in us getting back 7 instead of 5 in the + // getsockopt. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPDeferAccept); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 665b614e4a6e715bac25bea15c5c29184016e549 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Tue, 4 Feb 2020 18:04:26 -0800 Subject: Support RTM_NEWADDR and RTM_GETLINK in (rt)netlink. PiperOrigin-RevId: 293271055 --- pkg/sentry/inet/inet.go | 4 + pkg/sentry/inet/test_stack.go | 6 + pkg/sentry/socket/hostinet/stack.go | 5 + pkg/sentry/socket/netlink/BUILD | 14 +- pkg/sentry/socket/netlink/message.go | 129 +++++++++++ pkg/sentry/socket/netlink/message_test.go | 312 +++++++++++++++++++++++++++ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/route/BUILD | 2 - pkg/sentry/socket/netlink/route/protocol.go | 238 ++++++++++++++------ pkg/sentry/socket/netlink/socket.go | 54 ++--- pkg/sentry/socket/netlink/uevent/protocol.go | 2 +- pkg/sentry/socket/netstack/stack.go | 55 +++++ pkg/tcpip/stack/stack.go | 9 + test/syscalls/linux/BUILD | 2 + test/syscalls/linux/socket_netlink_route.cc | 296 ++++++++++++++++++++----- test/syscalls/linux/socket_netlink_util.cc | 45 +++- test/syscalls/linux/socket_netlink_util.h | 9 + 17 files changed, 1022 insertions(+), 162 deletions(-) create mode 100644 pkg/sentry/socket/netlink/message_test.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index a7dfb78a7..2916a0644 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -28,6 +28,10 @@ type Stack interface { // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr + // AddInterfaceAddr adds an address to the network interface identified by + // index. + AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index dcfcbd97e..d8961fc94 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -47,6 +47,12 @@ func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } +// AddInterfaceAddr implements Stack.AddInterfaceAddr. +func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { + s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 034eca676..a48082631 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -310,6 +310,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return addrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + return syserror.EACCES +} + // SupportsIPv6 implements inet.Stack.SupportsIPv6. func (s *Stack) SupportsIPv6() bool { return s.supportsIPv6 diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f8b8e467d..1911cd9b8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -33,3 +33,15 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "netlink_test", + size = "small", + srcs = [ + "message_test.go", + ], + deps = [ + ":netlink", + "//pkg/abi/linux", + ], +) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index b21e0ca4b..4ea252ccb 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -30,8 +30,16 @@ func alignUp(length int, align uint) int { return (length + int(align) - 1) &^ (int(align) - 1) } +// alignPad returns the length of padding required for alignment. +// +// Preconditions: align is a power of two. +func alignPad(length int, align uint) int { + return alignUp(length, align) - length +} + // Message contains a complete serialized netlink message. type Message struct { + hdr linux.NetlinkMessageHeader buf []byte } @@ -40,10 +48,86 @@ type Message struct { // The header length will be updated by Finalize. func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ + hdr: hdr, buf: binary.Marshal(nil, usermem.ByteOrder, hdr), } } +// ParseMessage parses the first message seen at buf, returning the rest of the +// buffer. If message is malformed, ok of false is returned. For last message, +// padding check is loose, if there isn't enought padding, whole buf is consumed +// and ok is set to true. +func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { + b := BytesView(buf) + + hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return + } + var hdr linux.NetlinkMessageHeader + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + // Msg portion. + totalMsgLen := int(hdr.Length) + _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize) + if !ok { + return + } + + // Padding. + numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return + } + + return &Message{ + hdr: hdr, + buf: buf[:totalMsgLen], + }, []byte(b), true +} + +// Header returns the header of this message. +func (m *Message) Header() linux.NetlinkMessageHeader { + return m.hdr +} + +// GetData unmarshals the payload message header from this netlink message, and +// returns the attributes portion. +func (m *Message) GetData(msg interface{}) (AttrsView, bool) { + b := BytesView(m.buf) + + _, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return nil, false + } + + size := int(binary.Size(msg)) + msgBytes, ok := b.Extract(size) + if !ok { + return nil, false + } + binary.Unmarshal(msgBytes, usermem.ByteOrder, msg) + + numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return nil, false + } + + return AttrsView(b), true +} + // Finalize returns the []byte containing the entire message, with the total // length set in the message header. The Message must not be modified after // calling Finalize. @@ -157,3 +241,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message { ms.Messages = append(ms.Messages, m) return m } + +// AttrsView is a view into the attributes portion of a netlink message. +type AttrsView []byte + +// Empty returns whether there is no attribute left in v. +func (v AttrsView) Empty() bool { + return len(v) == 0 +} + +// ParseFirst parses first netlink attribute at the beginning of v. +func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) { + b := BytesView(v) + + hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize) + if !ok { + return + } + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) + if !ok { + return + } + + _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO)) + if !ok { + return + } + + return hdr, value, AttrsView(b), ok +} + +// BytesView supports extracting data from a byte slice with bounds checking. +type BytesView []byte + +// Extract removes the first n bytes from v and returns it. If n is out of +// bounds, it returns false. +func (v *BytesView) Extract(n int) ([]byte, bool) { + if n < 0 || n > len(*v) { + return nil, false + } + extracted := (*v)[:n] + *v = (*v)[n:] + return extracted, true +} diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go new file mode 100644 index 000000000..ef13d9386 --- /dev/null +++ b/pkg/sentry/socket/netlink/message_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message_test + +import ( + "bytes" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" +) + +type dummyNetlinkMsg struct { + Foo uint16 +} + +func TestParseMessage(t *testing.T) { + tests := []struct { + desc string + input []byte + + header linux.NetlinkMessageHeader + dataMsg *dummyNetlinkMsg + restLen int + ok bool + }{ + { + desc: "valid", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid with next message", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + 0xFF, // Next message (rest) + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 1, + ok: true, + }, + { + desc: "valid for last message without padding", + input: []byte{ + 0x12, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + }, + header: linux.NetlinkMessageHeader{ + Length: 18, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid for last message not to be aligned", + input: []byte{ + 0x13, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + 0x00, // Excessive 1 byte permitted at end + }, + header: linux.NetlinkMessageHeader{ + Length: 19, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "header.Length too short", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header.Length too long", + input: []byte{ + 0xFF, 0xFF, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header incomplete", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + }, + ok: false, + }, + { + desc: "empty message", + input: []byte{}, + ok: false, + }, + } + for _, test := range tests { + msg, rest, ok := netlink.ParseMessage(test.input) + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + continue + } + if !test.ok { + continue + } + if !reflect.DeepEqual(msg.Header(), test.header) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) + } + + dataMsg := &dummyNetlinkMsg{} + _, dataOk := msg.GetData(dataMsg) + if !dataOk { + t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) + } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) + } + + if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) + } + } +} + +func TestAttrView(t *testing.T) { + tests := []struct { + desc string + input []byte + + // Outputs for ParseFirst. + hdr linux.NetlinkAttrHeader + value []byte + restLen int + ok bool + + // Outputs for Empty. + isEmpty bool + }{ + { + desc: "valid", + input: []byte{ + 0x06, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding + }, + hdr: linux.NetlinkAttrHeader{ + Length: 6, + Type: 1, + }, + value: []byte{0x30, 0x31}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment with rest data", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + 0xFF, 0xFE, // Rest data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 2, + ok: true, + isEmpty: false, + }, + { + desc: "hdr.Length too long", + input: []byte{ + 0xFF, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "hdr.Length too short", + input: []byte{ + 0x01, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "empty", + input: []byte{}, + ok: false, + isEmpty: true, + }, + } + for _, test := range tests { + attrs := netlink.AttrsView(test.input) + + // Test ParseFirst(). + hdr, value, rest, ok := attrs.ParseFirst() + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + } else if test.ok { + if !reflect.DeepEqual(hdr, test.hdr) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) + } + if !bytes.Equal(value, test.value) { + t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) + } + if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) + } + } + + // Test Empty(). + if got, want := attrs.Empty(), test.isEmpty; got != want { + t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) + } + } +} diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 07f860a49..b0dc70e5c 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -42,7 +42,7 @@ type Protocol interface { // If err == nil, any messages added to ms will be sent back to the // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE // message to be sent even if ms contains no messages. - ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error + ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error } // Provider is a function that creates a new Protocol for a specific netlink diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 622a1eafc..93127398d 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -10,13 +10,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket/netlink", "//pkg/syserr", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 2b3c7f5b3..c84d8bd7c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -17,16 +17,15 @@ package route import ( "bytes" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" ) // commandKind describes the operational class of a message type. @@ -69,13 +68,7 @@ func (p *Protocol) CanSend() bool { } // dumpLinks handles RTM_GETLINK dump requests. -func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported - } - +func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some // userspace applications (including glibc) still include rtgenmsg. @@ -99,44 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader return nil } - for id, i := range stack.Interfaces() { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) + for idx, i := range stack.Interfaces() { + addNewLinkMessage(ms, idx, i) + } - m.Put(linux.InterfaceInfoMessage{ - Family: linux.AF_UNSPEC, - Type: i.DeviceType, - Index: id, - Flags: i.Flags, - }) + return nil +} - m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) +// getLinks handles RTM_GETLINK requests. +func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network devices. + return nil + } - mac := make([]byte, 6) - brd := mac - if len(i.Addr) > 0 { - mac = i.Addr - brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) + // Parse message. + var ifi linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifi) + if !ok { + return syserr.ErrInvalidArgument + } + + // Parse attributes. + var byName []byte + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + attrs = rest - // TODO(gvisor.dev/issue/578): There are many more attributes. + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + byName = value[:len(value)-1] + + // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK. + } } + found := false + for idx, i := range stack.Interfaces() { + switch { + case ifi.Index > 0: + if idx != ifi.Index { + continue + } + case byName != nil: + if string(byName) != i.Name { + continue + } + default: + // Criteria not specified. + return syserr.ErrInvalidArgument + } + + addNewLinkMessage(ms, idx, i) + found = true + break + } + if !found { + return syserr.ErrNoDevice + } return nil } -// dumpAddrs handles RTM_GETADDR dump requests. -func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported +// addNewLinkMessage appends RTM_NEWLINK message for the given interface into +// the message set. +func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + + m.Put(linux.InterfaceInfoMessage{ + Family: linux.AF_UNSPEC, + Type: i.DeviceType, + Index: idx, + Flags: i.Flags, + }) + + m.PutAttrString(linux.IFLA_IFNAME, i.Name) + m.PutAttr(linux.IFLA_MTU, i.MTU) + + mac := make([]byte, 6) + brd := mac + if len(i.Addr) > 0 { + mac = i.Addr + brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } + m.PutAttr(linux.IFLA_ADDRESS, mac) + m.PutAttr(linux.IFLA_BROADCAST, brd) + + // TODO(gvisor.dev/issue/578): There are many more attributes. +} +// dumpAddrs handles RTM_GETADDR dump requests. +func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -168,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader Index: uint32(id), }) + m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -252,12 +307,12 @@ func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) { } // parseForDestination parses a message as format of RouteMessage-RtAttr-dst. -func parseForDestination(data []byte) ([]byte, *syserr.Error) { +func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) { var rtMsg linux.RouteMessage - if len(data) < linux.SizeOfRouteMessage { + attrs, ok := msg.GetData(&rtMsg) + if !ok { return nil, syserr.ErrInvalidArgument } - binary.Unmarshal(data[:linux.SizeOfRouteMessage], usermem.ByteOrder, &rtMsg) // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See // commit bc234301af12. Note we don't check this flag for backward // compatibility. @@ -265,26 +320,15 @@ func parseForDestination(data []byte) ([]byte, *syserr.Error) { return nil, syserr.ErrNotSupported } - data = data[linux.SizeOfRouteMessage:] - - // TODO(gvisor.dev/issue/1611): Add generic attribute parsing. - var rtAttr linux.RtAttr - if len(data) < linux.SizeOfRtAttr { - return nil, syserr.ErrInvalidArgument + // Expect first attribute is RTA_DST. + if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST { + return value, nil } - binary.Unmarshal(data[:linux.SizeOfRtAttr], usermem.ByteOrder, &rtAttr) - if rtAttr.Type != linux.RTA_DST { - return nil, syserr.ErrInvalidArgument - } - - if len(data) < int(rtAttr.Len) { - return nil, syserr.ErrInvalidArgument - } - return data[linux.SizeOfRtAttr:rtAttr.Len], nil + return nil, syserr.ErrInvalidArgument } // dumpRoutes handles RTM_GETROUTE requests. -func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETROUTE dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -295,10 +339,11 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } + hdr := msg.Header() routeTables := stack.RouteTable() if hdr.Flags == linux.NLM_F_REQUEST { - dst, err := parseForDestination(data) + dst, err := parseForDestination(msg) if err != nil { return err } @@ -357,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } +// newAddr handles RTM_NEWADDR requests. +func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err == syscall.EEXIST { + flags := msg.Header().Flags + if flags&linux.NLM_F_EXCL != 0 { + return syserr.ErrExists + } + } else if err != nil { + return syserr.ErrInvalidArgument + } + } + } + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + hdr := msg.Header() + // All messages start with a 1 byte protocol family. - if len(data) < 1 { + var family uint8 + if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. return nil @@ -374,16 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH } } - switch hdr.Type { - case linux.RTM_GETLINK: - return p.dumpLinks(ctx, hdr, data, ms) - case linux.RTM_GETADDR: - return p.dumpAddrs(ctx, hdr, data, ms) - case linux.RTM_GETROUTE: - return p.dumpRoutes(ctx, hdr, data, ms) - default: - return syserr.ErrNotSupported + if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP { + // TODO(b/68878065): Only the dump variant of the types below are + // supported. + switch hdr.Type { + case linux.RTM_GETLINK: + return p.dumpLinks(ctx, msg, ms) + case linux.RTM_GETADDR: + return p.dumpAddrs(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } + } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST { + switch hdr.Type { + case linux.RTM_GETLINK: + return p.getLink(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + case linux.RTM_NEWADDR: + return p.newAddr(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } } + return syserr.ErrNotSupported } // init registers the NETLINK_ROUTE provider. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index c4b95debb..2ca02567d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -644,47 +644,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error return nil } -func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error { +func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) - return nil +} +func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.NLMSG_ERROR, + }) + m.Put(linux.NetlinkErrorMessage{ + Error: 0, + Header: hdr, + }) } // processMessages handles each message in buf, passing it to the protocol // handler for final handling. func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { for len(buf) > 0 { - if len(buf) < linux.NetlinkMessageHeaderSize { + msg, rest, ok := ParseMessage(buf) + if !ok { // Linux ignores messages that are too short. See // net/netlink/af_netlink.c:netlink_rcv_skb. break } - - var hdr linux.NetlinkMessageHeader - binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr) - - if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) { - // Linux ignores malformed messages. See - // net/netlink/af_netlink.c:netlink_rcv_skb. - break - } - - // Data from this message. - data := buf[linux.NetlinkMessageHeaderSize:hdr.Length] - - // Advance to the next message. - next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO) - if next >= len(buf)-1 { - next = len(buf) - 1 - } - buf = buf[next:] + buf = rest + hdr := msg.Header() // Ignore control messages. if hdr.Type < linux.NLMSG_MIN_TYPE { @@ -692,19 +683,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } ms := NewMessageSet(s.portID, hdr.Seq) - var err *syserr.Error - // TODO(b/68877377): ACKs not supported yet. - if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { - err = syserr.ErrNotSupported - } else { - - err = s.protocol.ProcessMessage(ctx, hdr, data, ms) - } - if err != nil { - ms = NewMessageSet(s.portID, hdr.Seq) - if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil { - return err - } + if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil { + dumpErrorMesage(hdr, ms, err) + } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { + dumpAckMesage(hdr, ms) } if err := s.sendResponse(ctx, ms); err != nil { diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go index 1ee4296bc..029ba21b5 100644 --- a/pkg/sentry/socket/netlink/uevent/protocol.go +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -49,7 +49,7 @@ func (p *Protocol) CanSend() bool { } // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // Silently ignore all messages. return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 31ea66eca..0692482e9 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -20,6 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -88,6 +90,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + var ( + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + ) + switch addr.Family { + case linux.AF_INET: + if len(addr.Addr) < header.IPv4AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv4AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv4.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) + + case linux.AF_INET6: + if len(addr.Addr) < header.IPv6AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv6AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv6.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) + + default: + return syserror.ENOTSUP + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: int(addr.PrefixLen), + }, + } + + // Attach address to interface. + if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network. + s.Stack.AddRoute(tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: tcpip.NICID(idx), + }) + return nil +} + // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { var rs tcp.ReceiveBufferSizeOption diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7057b110e..b793f1d74 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -795,6 +795,8 @@ func (s *Stack) Forwarding() bool { // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. +// +// This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { s.mu.Lock() defer s.mu.Unlock() @@ -809,6 +811,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route { return append([]tcpip.Route(nil), s.routeTable...) } +// AddRoute appends a route to the route table. +func (s *Stack) AddRoute(route tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + s.routeTable = append(s.routeTable, route) +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 273b014d6..f2e3c7072 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2769,9 +2769,11 @@ cc_binary( deps = [ ":socket_netlink_util", ":socket_test_util", + "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 1e28e658d..e5aed1eec 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -25,8 +26,10 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" @@ -38,6 +41,8 @@ namespace testing { namespace { +constexpr uint32_t kSeq = 12345; + using ::testing::AnyOf; using ::testing::Eq; @@ -113,58 +118,224 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + // Loopback is common among all tests, check that it's found. + bool loopbackFound = false; + ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + CheckGetLinkResponse(hdr, kSeq, port); + if (hdr->nlmsg_type != RTM_NEWLINK) { + return; + } + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + std::cout << "Found interface idx=" << msg->ifi_index + << ", type=" << std::hex << msg->ifi_type; + if (msg->ifi_type == ARPHRD_LOOPBACK) { + loopbackFound = true; + EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); + } + })); + EXPECT_TRUE(loopbackFound); +} + +struct Link { + int index; + std::string name; +}; + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + absl::optional link; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (msg->ifi_type == ARPHRD_LOOPBACK) { + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + link = Link(); + link->index = msg->ifi_index; + link->name = std::string(reinterpret_cast(RTA_DATA(rta))); + } + })); + return link; +} + +// CheckLinkMsg checks a netlink message against an expected link. +void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { + ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->ifi_index, link.index); + + const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message."; + if (rta != nullptr) { + std::string name(reinterpret_cast(RTA_DATA(rta))); + EXPECT_EQ(name, link.name); + } +} + +TEST(NetlinkRouteTest, GetLinkByIndex) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link->index; - // Loopback is common among all tests, check that it's found. - bool loopbackFound = false; + bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; - if (msg->ifi_type == ARPHRD_LOOPBACK) { - loopbackFound = true; - EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); - } + CheckLinkMsg(hdr, *loopback_link); + found = true; }, false)); - EXPECT_TRUE(loopbackFound); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; } -TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { +TEST(NetlinkRouteTest, GetLinkByName) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; }; - constexpr uint32_t kSeq = 12345; + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); + strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + bool found = false; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + CheckLinkMsg(hdr, *loopback_link); + found = true; + }, + false)); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; +} + +TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = 1234590; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, GetLinkByNameNotFound) { + const std::string name = "nodevice?!"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); + strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; struct request req = {}; req.hdr.nlmsg_len = sizeof(req); @@ -175,18 +346,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); - - const struct nlmsgerr* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - EXPECT_EQ(msg->error, -EOPNOTSUPP); - }, - true)); + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { @@ -198,8 +359,6 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -238,8 +397,6 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -282,8 +439,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; // This control message is ignored. We still receive a response for the @@ -317,8 +472,6 @@ TEST(NetlinkRouteTest, GetAddrDump) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -367,6 +520,57 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +TEST(NetlinkRouteTest, AddAddr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifa; + struct rtattr rtattr; + struct in_addr addr; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_seq = kSeq; + req.ifa.ifa_family = AF_INET; + req.ifa.ifa_prefixlen = 24; + req.ifa.ifa_flags = 0; + req.ifa.ifa_scope = 0; + req.ifa.ifa_index = loopback_link->index; + req.rtattr.rta_type = IFA_LOCAL; + req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); + inet_pton(AF_INET, "10.0.0.1", &req.addr); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + + // Create should succeed, as no such address in kernel. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Replace an existing address should succeed. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Create exclusive should fail, as we created the address above. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_THAT( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), + PosixErrorIs(EEXIST, ::testing::_)); +} + // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { FileDescriptor fd = @@ -378,8 +582,6 @@ TEST(NetlinkRouteTest, GetRouteDump) { struct rtmsg rtm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETROUTE; @@ -538,8 +740,6 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -615,8 +815,6 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -695,8 +893,6 @@ TEST(NetlinkRouteTest, NoPasscredNoCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -743,8 +939,6 @@ TEST(NetlinkRouteTest, PasscredCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index cd2212a1a..952eecfe8 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -71,9 +72,10 @@ PosixError NetlinkRequestResponse( iov.iov_base = buf.data(); iov.iov_len = buf.size(); - // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE - // message. + // If NLM_F_MULTI is set, response is a series of messages that ends with a + // NLMSG_DONE message. int type = -1; + int flags = 0; do { int len; RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); @@ -89,6 +91,7 @@ PosixError NetlinkRequestResponse( for (struct nlmsghdr* hdr = reinterpret_cast(buf.data()); NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { fn(hdr); + flags = hdr->nlmsg_flags; type = hdr->nlmsg_type; // Done should include an integer payload for dump_done_errno. // See net/netlink/af_netlink.c:netlink_dump @@ -98,11 +101,11 @@ PosixError NetlinkRequestResponse( EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); } } - } while (type != NLMSG_DONE && type != NLMSG_ERROR); + } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); if (expect_nlmsgerr) { EXPECT_EQ(type, NLMSG_ERROR); - } else { + } else if (flags & NLM_F_MULTI) { EXPECT_EQ(type, NLMSG_DONE); } return NoError(); @@ -146,5 +149,39 @@ PosixError NetlinkRequestResponseSingle( return NoError(); } +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len) { + // Dummy negative number for no error message received. + // We won't get a negative error number so there will be no confusion. + int err = -42; + RETURN_IF_ERRNO(NetlinkRequestResponse( + fd, request, len, + [&](const struct nlmsghdr* hdr) { + EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); + EXPECT_EQ(hdr->nlmsg_seq, seq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + err = -msg->error; + }, + true)); + return PosixError(err); +} + +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr) { + const int ifi_space = NLMSG_SPACE(sizeof(*msg)); + int attrlen = hdr->nlmsg_len - ifi_space; + const struct rtattr* rta = reinterpret_cast( + reinterpret_cast(hdr) + NLMSG_ALIGN(ifi_space)); + for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { + if (rta->rta_type == attr) { + return rta; + } + } + return nullptr; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index 3678c0599..e13ead406 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -19,6 +19,7 @@ // socket.h has to be included before if_arp.h. #include #include +#include #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -47,6 +48,14 @@ PosixError NetlinkRequestResponseSingle( const FileDescriptor& fd, void* request, size_t len, const std::function& fn); +// Send the passed request then expect and return an ack or error. +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len); + +// Find rtnetlink attribute in message. +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr); + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From f3d95607036b8a502c65aa7b3e8145227274dbbc Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Wed, 5 Feb 2020 17:56:00 -0800 Subject: recv() on a closed TCP socket returns ENOTCONN From RFC 793 s3.9 p58 Event Processing: If RECEIVE Call arrives in CLOSED state and the user has access to such a connection, the return should be "error: connection does not exist" Fixes #1598 PiperOrigin-RevId: 293494287 --- pkg/sentry/socket/netstack/netstack.go | 7 ++++++- pkg/tcpip/tcpip.go | 4 ++++ pkg/tcpip/transport/tcp/endpoint.go | 4 ++-- pkg/tcpip/transport/tcp/tcp_test.go | 9 ++++----- test/syscalls/linux/tcp_socket.cc | 9 +++++++++ 5 files changed, 25 insertions(+), 8 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 049d04bf2..ed2fbcceb 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2229,11 +2229,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq var copied int // Copy as many views as possible into the user-provided buffer. - for dst.NumBytes() != 0 { + for { + // Always do at least one fetchReadView, even if the number of bytes to + // read is 0. err = s.fetchReadView() if err != nil { break } + if dst.NumBytes() == 0 { + break + } var n int var e error diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0fa141d58..d29d9a704 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1124,6 +1124,10 @@ type ReadErrors struct { // InvalidEndpointState is the number of times we found the endpoint state // to be unexpected. InvalidEndpointState StatCounter + + // NotConnected is the number of times we tried to read but found that the + // endpoint was not connected. + NotConnected StatCounter } // WriteErrors collects packet write errors from an endpoint write call. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5a8e15ee..e4a6b1b8b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1003,8 +1003,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected } v, err := e.readLocked() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2c1505067..cc118c993 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5405,12 +5405,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } - // Expect InvalidEndpointState errors on a read at this point. - if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { + t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected) } - if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { - t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1) } if err := ep.Listen(10); err != nil { diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 525ccbd88..8a8b68e75 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1339,6 +1339,15 @@ TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { EXPECT_EQ(get, kTCPDeferAccept); } +TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { + auto s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char buf[1]; + EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ENOTCONN)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 69bf39e8a47d3b4dcbbd04d2e8df476cdfab5e74 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 13 Feb 2020 10:58:47 -0800 Subject: Internal change. PiperOrigin-RevId: 294952610 --- pkg/abi/linux/socket.go | 13 ++++ pkg/sentry/socket/control/BUILD | 1 + pkg/sentry/socket/control/control.go | 43 +++++++++++++ pkg/sentry/socket/hostinet/socket.go | 11 +++- pkg/sentry/socket/netstack/netstack.go | 37 ++++++++++-- pkg/tcpip/tcpip.go | 25 ++++++++ pkg/tcpip/transport/udp/endpoint.go | 26 ++++++++ test/syscalls/linux/socket_ip_udp_generic.cc | 44 ++++++++++++++ test/syscalls/linux/socket_ipv4_udp_unbound.cc | 84 ++++++++++++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 1 - 10 files changed, 278 insertions(+), 7 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 766ee4014..4a14ef691 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -411,6 +411,15 @@ type ControlMessageCredentials struct { GID uint32 } +// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. +// +// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +type ControlMessageIPPacketInfo struct { + NIC int32 + LocalAddr InetAddr + DestinationAddr InetAddr +} + // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) @@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1 // SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. const SizeOfControlMessageTClass = 4 +// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO +// control message. +const SizeOfControlMessageIPPacketInfo = 12 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 79e16d6e8..4d42d29cb 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/syserror", + "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6145a7fc3..4667373d2 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -338,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { ) } +// PackIPPacketInfo packs an IP_PKTINFO socket control message. +func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_PKTINFO, + t.Arch().Width(), + p, + ) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -362,6 +379,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackTClass(t, cmsgs.IP.TClass, buf) } + if cmsgs.IP.HasIPPacketInfo { + buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + } + return buf } @@ -394,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { return space } +// NewIPPacketInfo returns the IPPacketInfo struct. +func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { + var p tcpip.IPPacketInfo + p.NIC = tcpip.NICID(packetInfo.NIC) + copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) + copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + + return p +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( @@ -468,6 +499,18 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) i += binary.AlignUp(length, width) + case linux.IP_PKTINFO: + if length < linux.SizeOfControlMessageIPPacketInfo { + return socket.ControlMessages{}, syserror.EINVAL + } + + cmsgs.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + + cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index de76388ac..22f78d2e2 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: optlen = sizeofInt32 } case linux.SOL_IPV6: @@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ switch name { case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 + case linux.IP_PKTINFO: + optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { @@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags case syscall.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + + case syscall.IP_PKTINFO: + controlMessages.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) } + case syscall.SOL_IPV6: switch unixCmsg.Header.Type { case syscall.IPV6_TCLASS: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index ed2fbcceb..9757fbfba 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1414,6 +1414,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return o, nil + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIP(t, name) } @@ -1762,6 +1777,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, + // TODO(b/148887420): Add support for IPV6_PKTINFO. linux.IPV6_PKTINFO, linux.IPV6_ROUTER_ALERT, linux.IPV6_XFRM_POLICY, @@ -1949,6 +1965,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + case linux.IP_PKTINFO: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1964,7 +1990,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_PKTINFO, linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, @@ -2395,10 +2420,12 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe func (s *SocketOperations) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, }, } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0e944712f..9ca39ce40 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -328,6 +328,12 @@ type ControlMessages struct { // Tclass is the IPv6 traffic class of the associated packet. TClass int32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo IPPacketInfo } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -503,6 +509,11 @@ const ( // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. V6OnlyOption + + // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify + // if more inforamtion is provided with incoming packets such + // as interface index and address. + ReceiveIPPacketInfoOption ) // SockOptInt represents socket options which values have the int type. @@ -685,6 +696,20 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 +// IPPacketInfo is the message struture for IP_PKTINFO. +// +// +stateify savable +type IPPacketInfo struct { + // NIC is the ID of the NIC to be used. + NIC NICID + + // LocalAddr is the local address. + LocalAddr Address + + // DestinationAddr is the destination address. + DestinationAddr Address +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c9cbed8f4..3fe91cac2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -29,6 +29,7 @@ import ( type udpPacket struct { udpPacketEntry senderAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 tos uint8 @@ -118,6 +119,9 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveIPPacketInfo determines if the packet info is returned by Read. + receiveIPPacketInfo bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -254,11 +258,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } + + if receiveIPPacketInfo { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } return p.data.ToView(), cm, nil } @@ -495,6 +505,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v + return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + return nil } return nil @@ -703,6 +720,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil } return false, tcpip.ErrUnknownProtocolOption @@ -1247,6 +1270,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.RemoteAddress + packet.packetInfo.NIC = r.NICID() } packet.timestamp = e.stack.NowNanoseconds() diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 53290bed7..db5663ecd 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -357,5 +357,49 @@ TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { EXPECT_EQ(get, kSockOptOn); } +// Test getsockopt for a socket which is not set with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, IPPKTINFODefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test setsockopt and getsockopt for a socket with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int level = SOL_IP; + int type = IP_PKTINFO; + + // Check getsockopt before IP_PKTINFO is set. + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_len, sizeof(get)); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_len, sizeof(get)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 990ccf23c..bc4b07a62 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/socket_ipv4_udp_unbound.h" #include +#include #include #include #include @@ -2128,5 +2129,88 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(kMessageSize)); } +// Test that socket will receive packet info control message. +TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF((IsRunningWithHostinet())); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto sender_addr = V4Loopback(); + int level = SOL_IP; + int type = IP_PKTINFO; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + auto receiver_addr = V4Loopback(); + reinterpret_cast(&receiver_addr.addr)->sin_port = + reinterpret_cast(&sender_addr.addr)->sin_port; + ASSERT_THAT( + connect(sender->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Allow socket to receive control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + // Check the data + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index a2f6ef8cc..9f8de6b48 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1495,6 +1495,5 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); EXPECT_EQ(received_tos, sent_tos); } - } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 56fd9504aab44a738d3df164cbee8e572b309f28 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 18 Feb 2020 15:44:22 -0800 Subject: Enable IPV6_RECVTCLASS socket option for datagram sockets Added the ability to get/set the IP_RECVTCLASS socket option on UDP endpoints. If enabled, traffic class from the incoming Network Header passed as ancillary data in the ControlMessages. Adding Get/SetSockOptBool to decrease the overhead of getting/setting simple options. (This was absorbed in a CL that will be landing before this one). Test: * Added unit test to udp_test.go that tests getting/setting as well as verifying that we receive expected TOS from incoming packet. * Added a syscall test for verifying getting/setting * Removed test skip for existing syscall test to enable end to end test. PiperOrigin-RevId: 295840218 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 27 +++++- pkg/tcpip/checker/checker.go | 14 +++ pkg/tcpip/tcpip.go | 15 ++- pkg/tcpip/transport/udp/endpoint.go | 38 +++++++- pkg/tcpip/transport/udp/udp_test.go | 120 ++++++++++++++---------- test/syscalls/linux/ip_socket_test_util.h | 16 ++-- test/syscalls/linux/socket_ip_udp_generic.cc | 133 +++++++++++++++++++-------- test/syscalls/linux/udp_socket_test_cases.cc | 4 - 9 files changed, 260 insertions(+), 109 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4667373d2..8834a1e1a 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -329,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { } // PackTClass packs an IPV6_TCLASS socket control message. -func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { +func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IPV6, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9757fbfba..e187276c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } return ib, nil + case linux.IPV6_RECVTCLASS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1803,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + case linux.IPV6_RECVTCLASS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + default: emitUnimplementedEventIPv6(t, name) } @@ -2086,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, - linux.IPV6_RECVTCLASS, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, linux.IPV6_TCLASS, @@ -2424,6 +2447,8 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages { Timestamp: s.readCM.Timestamp, HasTOS: s.readCM.HasTOS, TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, HasIPPacketInfo: s.readCM.HasIPPacketInfo, PacketInfo: s.readCM.PacketInfo, }, diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 4d6ae0871..c6c160dfc 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTClass creates a checker that checks the TCLASS field in +// ControlMessages. +func ReceiveTClass(want uint32) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTClass { + t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) + } + if got := cm.TClass; got != want { + t.Fatalf("got cm.TClass = %d, want %d", got, want) + } + } +} + // ReceiveTOS creates a checker that checks the TOS field in ControlMessages. func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 9ca39ce40..ce5527391 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -323,11 +323,11 @@ type ControlMessages struct { // TOS is the IPv4 type of service of the associated packet. TOS uint8 - // HasTClass indicates whether Tclass is valid/set. + // HasTClass indicates whether TClass is valid/set. HasTClass bool - // Tclass is the IPv6 traffic class of the associated packet. - TClass int32 + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 // HasIPPacketInfo indicates whether PacketInfo is set. HasIPPacketInfo bool @@ -502,9 +502,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the + // IPV6_TCLASS ancillary message is passed with incoming packets. + ReceiveTClassOption SockOptBool = iota + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. - ReceiveTOSOption SockOptBool = iota + ReceiveTOSOption // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. @@ -514,6 +518,9 @@ const ( // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption + + // TODO(b/146901447): convert existing bool socket options to be handled via + // Get/SetSockOptBool ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3fe91cac2..eff7f3600 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,7 +32,8 @@ type udpPacket struct { packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - tos uint8 + // tos stores either the receiveTOS or receiveTClass value. + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -119,6 +120,10 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveTClass determines if the incoming IPv6 TClass header field is + // passed as ancillary data to ControlMessages on Read. + receiveTClass bool + // receiveIPPacketInfo determines if the packet info is returned by Read. receiveIPPacketInfo bool @@ -258,13 +263,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveTClass := e.receiveTClass receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } - + if receiveTClass { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } if receiveIPPacketInfo { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo @@ -490,6 +500,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrNotSupported + } + + e.mu.Lock() + e.receiveTClass = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -709,6 +730,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrNotSupported + } + + e.mu.RLock() + v := e.receiveTClass + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1273,6 +1305,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.packetInfo.LocalAddr = r.LocalAddress packet.packetInfo.DestinationAddr = r.RemoteAddress packet.packetInfo.NIC = r.NICID() + case header.IPv6ProtocolNumber: + packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } packet.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index f0ff3fe71..34b7c2360 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, @@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) { } } -func TestTOSV4(t *testing.T) { +func TestSetTOS(t *testing.T) { for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) { const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) { } } -func TestTOSV6(t *testing.T) { +func TestSetTClass(t *testing.T) { for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tClass = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if want := tcpip.IPv6TrafficClassOption(tClass); v != want { + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } - testWrite(c, flow, checker.TOS(tos, 0)) + // The header getter for TClass is called TOS, so use that checker. + testWrite(c, flow, checker.TOS(tClass, 0)) }) } } -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() +func TestReceiveTosTClass(t *testing.T) { + testCases := []struct { + name string + getReceiveOption tcpip.SockOptBool + tests []testFlow + }{ + {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, + {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + } + for _, testCase := range testCases { + for _, flow := range testCase.tests { + t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpointForFlow(flow) + c.createEndpointForFlow(flow) + option := testCase.getReceiveOption + name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) - } + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) + } - want := true - if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { - c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) - } + want := true + if err := c.ep.SetSockOptBool(option, want); err != nil { + c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) + } - got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - if got != want { - c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) - } + got, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } - // Verify that the correct received TOS is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) + if got != want { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) + } + + // Verify that the correct received TOS or TClass is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + switch option { + case tcpip.ReceiveTClassOption: + testRead(c, flow, checker.ReceiveTClass(testTOS)) + case tcpip.ReceiveTOSOption: + testRead(c, flow, checker.ReceiveTOS(testTOS)) + default: + t.Fatalf("unknown test variant: %s", name) + } + }) + } } } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 083ebbcf0..39fd6709d 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -84,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); // SocketPairs created with AF_INET and the given type. SocketPairKind IPv4UDPUnboundSocketPair(int type); -// IPv4UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. +// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); -// IPv6UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_DGRAM, and the given type. SocketKind IPv6UDPUnboundSocket(int type); -// IPv4TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. +// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); -// IPv6TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_STREAM and the given type. SocketKind IPv6TCPUnboundSocket(int type); // IfAddrHelper is a helper class that determines the local interfaces present diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index db5663ecd..1c533fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" +#include #include #include #include @@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -401,5 +362,97 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { EXPECT_EQ(get_len, sizeof(get)); } +// Holds TOS or TClass information for IPv4 or IPv6 respectively. +struct RecvTosOption { + int level; + int option; +}; + +RecvTosOption GetRecvTosOption(int domain) { + TEST_CHECK(domain == AF_INET || domain == AF_INET6); + RecvTosOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_RECVTOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_RECVTCLASS; + break; + } + return opt; +} + +// Ensure that Receiving TOS or TCLASS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as +// expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this +// mirrors behavior in linux. +TEST_P(UDPSocketPairTest, TOSRecvMismatch) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(AF_INET); + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); +} + +// Test that an IPv4 socket does not support the IPv6 TClass option. +TEST_P(UDPSocketPairTest, TClassRecvMismatch) { + // This should only test AF_INET sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, + &get, &get_len), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 9f8de6b48..57b1a357c 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,9 +1349,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,7 +1419,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 2daa21e4d73f2297a8bca32c76100333e9ac4af4 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 19 Feb 2020 16:47:58 -0800 Subject: Internal change. PiperOrigin-RevId: 296088213 --- pkg/sentry/socket/netstack/provider.go | 2 ++ 1 file changed, 2 insertions(+) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5afff2564..5f181f017 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in switch protocol { case syscall.IPPROTO_ICMP: return header.ICMPv4ProtocolNumber, true, nil + case syscall.IPPROTO_ICMPV6: + return header.ICMPv6ProtocolNumber, true, nil case syscall.IPPROTO_UDP: return header.UDPProtocolNumber, true, nil case syscall.IPPROTO_TCP: -- cgit v1.2.3 From abf7ebcd38e8c2750f4542f29115140bb2b44a9b Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Thu, 27 Feb 2020 10:59:32 -0800 Subject: Internal change. PiperOrigin-RevId: 297638665 --- pkg/sentry/socket/netstack/netstack.go | 40 +++++++++-- pkg/tcpip/transport/packet/endpoint.go | 21 +++++- test/syscalls/linux/packet_socket.cc | 124 ++++++++++++++++++++++++++++++--- 3 files changed, 167 insertions(+), 18 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e187276c5..48c268bfa 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -712,14 +712,40 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) - if err != nil { - return err - } - if err := s.checkFamily(family, true /* exact */); err != nil { - return err + family := usermem.ByteOrder.Uint16(sockaddr) + var addr tcpip.FullAddress + + // Bind for AF_PACKET requires only family, protocol and ifindex. + // In function AddressAndFamily, we check the address length which is + // not needed for AF_PACKET bind. + if family == linux.AF_PACKET { + var a linux.SockAddrLink + if len(sockaddr) < sockAddrLinkSize { + return syserr.ErrInvalidArgument + } + binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + + if a.Protocol != uint16(s.protocol) { + return syserr.ErrInvalidArgument + } + + addr = tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + } + } else { + var err *syserr.Error + addr, family, err = AddressAndFamily(sockaddr) + if err != nil { + return err + } + + if err = s.checkFamily(family, true /* exact */); err != nil { + return err + } + + addr = s.mapFamily(addr, family) } - addr = s.mapFamily(addr, family) // Issue the bind request to the endpoint. return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 5722815e9..09a1cd436 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -76,6 +76,7 @@ type endpoint struct { sndBufSize int closed bool stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool } // NewEndpoint returns a new packet endpoint. @@ -125,6 +126,7 @@ func (ep *endpoint) Close() { } ep.closed = true + ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -216,7 +218,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." // - packet(7). - return tcpip.ErrNotSupported + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound { + return tcpip.ErrAlreadyBound + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + + return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 92ae55eec..bc22de788 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() { return ifr.ifr_ifindex; } -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - +// Receive and verify the message via packet socket on interface. +void ReceiveMessage(int sock, int ifindex) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = sock; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, reinterpret_cast(&src), &src_len), SyscallSucceedsWithValue(packet_size)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns @@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) { // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) { EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); } +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + int loopback_index = GetLoopbackIndex(); + ReceiveMessage(socket_, loopback_index); +} + // Send via a packet socket. TEST_P(CookedPacketTest, Send) { // TODO(b/129292371): Remove once we support packet socket writing. @@ -313,6 +322,101 @@ TEST_P(CookedPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Bind and receive via packet socket. +TEST_P(CookedPacketTest, BindReceive) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + ReceiveMessage(socket_, bind_addr.sll_ifindex); +} + +// Double Bind socket. +TEST_P(CookedPacketTest, DoubleBind) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Binding socket again should fail. + ASSERT_THAT( + bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched + // to EADDRINUSE. + AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); +} + +// Bind and verify we do not receive data on interface which is not bound +TEST_P(CookedPacketTest, BindDrop) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + + // Bind to packet socket requires only family, protocol and ifindex. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = ifr.ifr_ifindex; + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Send to loopback interface. + struct sockaddr_in dest = {}; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket never receives any data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); -- cgit v1.2.3 From 42fb7d349137bd8847e7c3df6493fde3bc8e6e89 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 2 Mar 2020 10:32:20 -0800 Subject: socket: take readMu to access readView DATA RACE in netstack.(*SocketOperations).fetchReadView Write at 0x00c001dca138 by goroutine 1001: gvisor.dev/gvisor/pkg/sentry/socket/netstack.(*SocketOperations).fetchReadView() pkg/sentry/socket/netstack/netstack.go:418 +0x85 gvisor.dev/gvisor/pkg/sentry/socket/netstack.(*SocketOperations).coalescingRead() pkg/sentry/socket/netstack/netstack.go:2309 +0x67 gvisor.dev/gvisor/pkg/sentry/socket/netstack.(*SocketOperations).nonBlockingRead() pkg/sentry/socket/netstack/netstack.go:2378 +0x183d Previous read at 0x00c001dca138 by goroutine 1111: gvisor.dev/gvisor/pkg/sentry/socket/netstack.(*SocketOperations).Ioctl() pkg/sentry/socket/netstack/netstack.go:2666 +0x533 gvisor.dev/gvisor/pkg/sentry/syscalls/linux.Ioctl() Reported-by: syzbot+d4c3885fcc346f08deb6@syzkaller.appspotmail.com PiperOrigin-RevId: 298387377 --- pkg/sentry/socket/netstack/netstack.go | 2 ++ 1 file changed, 2 insertions(+) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 48c268bfa..1eeb37446 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2663,7 +2663,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } // Add bytes removed from the endpoint but not yet sent to the caller. + s.readMu.Lock() v += len(s.readView) + s.readMu.Unlock() if v > math.MaxInt32 { v = math.MaxInt32 -- cgit v1.2.3 From 43abb24657e737dee1108ff0d512b2e1b6d8a3f6 Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Mon, 2 Mar 2020 16:30:51 -0800 Subject: Fix panic caused by invalid address for Bind in packet sockets. PiperOrigin-RevId: 298476533 --- pkg/sentry/socket/netstack/netstack.go | 4 ++++ test/syscalls/linux/packet_socket.cc | 13 +++++++++++++ 2 files changed, 17 insertions(+) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 1eeb37446..13a9a60b4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -712,6 +712,10 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + if len(sockaddr) < 2 { + return syserr.ErrInvalidArgument + } + family := usermem.ByteOrder.Uint16(sockaddr) var addr tcpip.FullAddress diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index bc22de788..248762ca9 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -417,6 +417,19 @@ TEST_P(CookedPacketTest, BindDrop) { EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); } +// Bind with invalid address. +TEST_P(CookedPacketTest, BindFail) { + // Null address. + ASSERT_THAT(bind(socket_, nullptr, sizeof(struct sockaddr)), + SyscallFailsWithErrno(EFAULT)); + + // Address of size 1. + uint8_t addr = 0; + ASSERT_THAT( + bind(socket_, reinterpret_cast(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); -- cgit v1.2.3 From e9e399c25d4fcad2adfe92d73b192b9784774964 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 19 Mar 2020 07:18:47 -0700 Subject: Remove workMu from tcpip.Endpoint. workMu is removed and e.mu is now a mutex that supports TryLock. The packet processing path tries to lock the mutex and if its locked it will just queue the packet and move on. The endpoint.UnlockUser() will process any backlog of packets before unlocking the socket. This simplifies the locking inside tcp endpoints a lot. Further the endpoint.LockUser() implements spinning as long as the lock is not held by another syscall goroutine. This ensures low latency as not spinning leads to the task thread being put to sleep if the lock is held by the packet dispatch path. This is suboptimal as the lower layer rarely holds the lock for long so implementing spinning here helps. If the lock is held by another task goroutine then we just proceed to call LockUser() and the task could be put to sleep. The protocol goroutines themselves just call e.mu.Lock() and block if the lock is currently not available. Updates #231, #357 PiperOrigin-RevId: 301808349 --- pkg/sentry/kernel/epoll/epoll.go | 2 + pkg/sentry/socket/netstack/netstack.go | 24 +- pkg/tcpip/transport/tcp/accept.go | 90 +++--- pkg/tcpip/transport/tcp/connect.go | 78 ++--- pkg/tcpip/transport/tcp/dispatcher.go | 8 +- pkg/tcpip/transport/tcp/endpoint.go | 495 ++++++++++++++++-------------- pkg/tcpip/transport/tcp/endpoint_state.go | 5 +- pkg/tcpip/transport/tcp/protocol.go | 38 +-- pkg/tcpip/transport/tcp/rcv.go | 6 - pkg/tcpip/transport/tcp/segment_queue.go | 8 +- pkg/tcpip/transport/tcp/snd.go | 3 - pkg/tcpip/transport/tcp/tcp_test.go | 41 ++- 12 files changed, 424 insertions(+), 374 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 8bffb78fc..592650923 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -296,8 +296,10 @@ func (*readyCallback) Callback(w *waiter.Entry) { e.waitingList.Remove(entry) e.readyList.PushBack(entry) entry.curList = &e.readyList + e.listsMu.Unlock() e.Notify(waiter.EventIn) + return } e.listsMu.Unlock() diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 13a9a60b4..a2e1da02f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,6 +29,7 @@ import ( "io" "math" "reflect" + "sync/atomic" "syscall" "time" @@ -264,6 +265,12 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. @@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool { // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. +// +// Precondition: s.readMu must be held. func (s *SocketOperations) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r @@ -2334,6 +2342,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2456,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 85049e54e..4d7602d54 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -221,7 +221,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. +// the connection parameters given by the arguments. The endpoint is returned +// with n.mu held. func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto @@ -243,21 +244,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() - // Now inherit any socket options that should be inherited from the - // listening endpoint. - // In case of Forwarder listenEP will be nil and hence this check. - if l.listenEP != nil { - l.listenEP.propagateInheritableOptions(n) - } - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - // Create sender and receiver. // // The receiver at least temporarily has a zero receive window scale, @@ -269,11 +255,27 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + n.mu.Lock() + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + return nil, err + } + + n.isRegistered = true + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state // and then performs the TCP 3-way handshake. +// +// The new endpoint is returned with e.mu held. func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber @@ -289,9 +291,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() + // Ensure we release any registrations done by the newly + // created endpoint. + ep.mu.Unlock() + ep.Close() + + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + l.listenEP.propagateInheritableOptionsLocked(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } @@ -299,6 +317,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { + ep.mu.Unlock() ep.Close() // Wake up any waiters. This is strictly not required normally // as a socket that was never accepted can't really have any @@ -312,9 +331,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head } return nil, err } - ep.mu.Lock() ep.isConnectNotified = true - ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -366,12 +383,12 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } -// propagateInheritableOptions propagates any options set on the listening +// propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. -func (e *endpoint) propagateInheritableOptions(n *endpoint) { - e.mu.Lock() +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout - e.mu.Unlock() } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -382,7 +399,11 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer decSynRcvdCount() - defer e.decSynRcvdCount() + defer func() { + e.mu.Lock() + e.decSynRcvdCount() + e.mu.Unlock() + }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) @@ -399,29 +420,21 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - e.mu.Lock() - if e.synRcvdCount >= cap(e.acceptedChan) { - e.mu.Unlock() + if e.synRcvdCount >= (cap(e.acceptedChan)) { return false } e.synRcvdCount++ - e.mu.Unlock() return true } func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() e.synRcvdCount-- - e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { - e.mu.Lock() if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - e.mu.Unlock() return true } - e.mu.Unlock() return false } @@ -559,6 +572,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + // clear the tsOffset for the newly created // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was @@ -593,14 +610,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only - e.mu.Unlock() ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections // to the endpoint. - e.mu.Lock() e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. @@ -622,7 +637,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { - switch index, _ := s.Fetch(true); index { + e.mu.Unlock() + index, _ := s.Fetch(true) + e.mu.Lock() + switch index { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { @@ -635,7 +653,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.decRef() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index be86af502..edb37a549 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -61,6 +61,9 @@ const ( ) // handshake holds the state used during a TCP 3-way handshake. +// +// NOTE: handshake.ep.mu is held during handshake processing. It is released if +// we are going to block and reacquired when we start processing an event. type handshake struct { ep *endpoint state handshakeState @@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.deferAccept = deferAccept - h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." - h.ep.mu.Lock() h.ep.workerCleanup = true - h.ep.mu.Unlock() // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused @@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) - h.ep.mu.Unlock() h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil @@ -293,11 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd - h.ep.mu.Lock() ttl := h.ep.ttl amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, @@ -357,10 +352,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - h.ep.mu.RLock() - amss := h.ep.amss - h.ep.mu.RUnlock() - h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -368,7 +359,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: amss, + MSS: h.ep.amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -399,15 +390,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { } h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver // to process it again once main loop is started. if s.data.Size() > 0 { s.incRef() h.ep.enqueueSegment(s) } - h.ep.mu.Unlock() return nil } @@ -493,7 +483,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } if n¬ifyDrain != 0 { close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } } @@ -535,7 +527,6 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -546,7 +537,6 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } - h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the @@ -563,7 +553,11 @@ func (h *handshake) execute() *tcpip.Error { h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { - switch index, _ := s.Fetch(true); index { + h.ep.mu.Unlock() + index, _ := s.Fetch(true) + h.ep.mu.Lock() + switch index { + case wakerForResend: timeOut *= 2 if timeOut > MaxRTO { @@ -600,7 +594,9 @@ func (h *handshake) execute() *tcpip.Error { } } close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } case wakerForNewSegment: @@ -1016,7 +1012,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - e.mu.Lock() switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set @@ -1040,11 +1035,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { case StateCloseWait: e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: - e.mu.Unlock() // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So @@ -1157,9 +1150,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - e.mu.RLock() state := e.state - e.mu.RUnlock() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1182,9 +1173,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { - e.mu.RLock() userTimeout := e.userTimeout - e.mu.RUnlock() e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { @@ -1248,6 +1237,7 @@ func (e *endpoint) disableKeepaliveTimer() { // goroutine and is responsible for sending segments and handling received // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { + e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1269,7 +1259,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() - e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1280,16 +1269,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // completion. initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) - e.mu.Lock() h.ep.setEndpointState(StateSynSent) - e.mu.Unlock() if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() - e.mu.Lock() e.setEndpointState(StateError) e.HardError = err @@ -1302,9 +1288,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - e.mu.Lock() drained := e.drainDone != nil - e.mu.Unlock() if drained { close(e.drainDone) <-e.undrain @@ -1330,10 +1314,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. - e.mu.Lock() e.transitionToStateCloseLocked() e.workerCleanup = true - e.mu.Unlock() return nil }, }, @@ -1388,7 +1370,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyClose != 0 && closeTimer == nil { - e.mu.Lock() if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. @@ -1397,7 +1378,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } - e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1417,7 +1397,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } } @@ -1460,7 +1442,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.rcvListMu.Unlock() - e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } @@ -1468,7 +1449,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. cleanupOnError := func(err *tcpip.Error) { - e.mu.Lock() e.workerCleanup = true if err != nil { e.resetConnectionLocked(err) @@ -1480,16 +1460,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ loop: for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() - e.workMu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() - // We need to double check here because the notification maybe + // We need to double check here because the notification may be // stale by the time we got around to processing it. - // - // NOTE: since we now hold the workMu the processors cannot - // change the state of the endpoint so it's safe to proceed - // after this check. switch e.EndpointState() { case StateError: // If the endpoint has already transitioned to an ERROR @@ -1502,21 +1477,17 @@ loop: case StateTimeWait: fallthrough case StateClose: - e.mu.Lock() break loop default: if err := funcs[v].f(); err != nil { cleanupOnError(err) return nil } - e.mu.Lock() } } - state := e.EndpointState() - e.mu.Unlock() var reuseTW func() - if state == StateTimeWait { + if e.EndpointState() == StateTimeWait { // Disable close timer as we now entering real TIME_WAIT. if closeTimer != nil { closeTimer.Stop() @@ -1526,14 +1497,11 @@ loop: s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.mu.Lock() e.workerCleanup = true - e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. - e.mu.Lock() if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1649,9 +1617,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { defer timeWaitTimer.Stop() for { - e.workMu.Unlock() + e.mu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() switch v { case newSegment: extendTimeWait, reuseTW := e.handleTimeWaitSegments() @@ -1674,7 +1642,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { e.handleTimeWaitSegments() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() return nil } case timeWaitDone: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index d792b07d6..90ac956a9 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -128,7 +128,7 @@ func (p *processor) handleSegments() { continue } - if !ep.workMu.TryLock() { + if !ep.mu.TryLock() { ep.newSegmentWaker.Assert() continue } @@ -138,12 +138,10 @@ func (p *processor) handleSegments() { if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { // Send any active resets if required. if err != nil { - ep.mu.Lock() ep.resetConnectionLocked(err) - ep.mu.Unlock() } ep.notifyProtocolGoroutine(notifyTickleWorker) - ep.workMu.Unlock() + ep.mu.Unlock() continue } @@ -151,7 +149,7 @@ func (p *processor) handleSegments() { p.epQ.enqueue(ep) } - ep.workMu.Unlock() + ep.mu.Unlock() } } } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 5187a5e25..eb8a9d73e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "fmt" "math" + "runtime" "strings" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tmutex" "gvisor.dev/gvisor/pkg/waiter" ) @@ -283,6 +283,37 @@ func (*EndpointInfo) IsEndpointInfo() {} // synchronized. The protocol implementation, however, runs in a single // goroutine. // +// Each endpoint has a few mutexes: +// +// e.mu -> Primary mutex for an endpoint must be held for all operations except +// in e.Readiness where acquiring it will result in a deadlock in epoll +// implementation. +// +// The following three mutexes can be acquired independent of e.mu but if +// acquired with e.mu then e.mu must be acquired first. +// +// e.rcvListMu -> Protects the rcvList and associated fields. +// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.lastErrorMu -> Protects the lastError field. +// +// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different +// based on the context in which the lock is acquired. In the syscall context +// e.LockUser/e.UnlockUser should be used and when doing background processing +// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below +// in brief. +// +// The reason for this locking behaviour is to avoid wakeups to handle packets. +// In cases where the endpoint is already locked the background processor can +// queue the packet up and go its merry way and the lock owner will eventually +// process the backlog when releasing the lock. Similarly when acquiring the +// lock from say a syscall goroutine we can implement a bit of spinning if we +// know that the lock is not held by another syscall goroutine. Background +// processors should never hold the lock for long and we can avoid an expensive +// sleep/wakeup by spinning for a shortwhile. +// +// For more details please see the detailed documentation on +// e.LockUser/e.UnlockUser methods. +// // +stateify savable type endpoint struct { EndpointInfo @@ -299,12 +330,6 @@ type endpoint struct { // Precondition: epQueue.mu must be held to read/write this field.. pendingProcessing bool `state:"nosave"` - // workMu is used to arbitrate which goroutine may perform protocol - // work. Only the main protocol goroutine is expected to call Lock() on - // it, but other goroutines (e.g., send) may call TryLock() to eagerly - // perform work without having to wait for the main one to wake up. - workMu tmutex.Mutex `state:"nosave"` - // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` @@ -330,15 +355,11 @@ type endpoint struct { rcvBufSize int rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams - // zeroWindow indicates that the window was closed due to receive buffer - // space being filled up. This is set by the worker goroutine before - // moving a segment to the rcvList. This setting is cleared by the - // endpoint when a Read() call reads enough data for the new window to - // be non-zero. - zeroWindow bool - // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` + // mu protects all endpoint fields unless documented otherwise. mu must + // be acquired before interacting with the endpoint fields. + mu sync.Mutex `state:"nosave"` + ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` @@ -583,14 +604,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { return maxMSS } +// LockUser tries to lock e.mu and if it fails it will check if the lock is held +// by another syscall goroutine. If yes, then it will goto sleep waiting for the +// lock to be released, if not then it will spin till it acquires the lock or +// another syscall goroutine acquires it in which case it will goto sleep as +// described above. +// +// The assumption behind spinning here being that background packet processing +// should not be holding the lock for long and spinning reduces latency as we +// avoid an expensive sleep/wakeup of of the syscall goroutine). +func (e *endpoint) LockUser() { + for { + // Try first if the sock is locked then check if it's owned + // by another user goroutine if not then we spin, otherwise + // we just goto sleep on the Lock() and wait. + if !e.mu.TryLock() { + // If socket is owned by the user then just goto sleep + // as the lock could be held for a reasonably long time. + if atomic.LoadUint32(&e.ownedByUser) == 1 { + e.mu.Lock() + atomic.StoreUint32(&e.ownedByUser, 1) + return + } + // Spin but yield the processor since the lower half + // should yield the lock soon. + runtime.Gosched() + continue + } + atomic.StoreUint32(&e.ownedByUser, 1) + return + } +} + +// UnlockUser will check if there are any segments already queued for processing +// and process any such segments before unlocking e.mu. This is required because +// we when packets arrive and endpoint lock is already held then such packets +// are queued up to be processed. If the lock is held by the endpoint goroutine +// then it will process these packets but if the lock is instead held by the +// syscall goroutine then we can have the syscall goroutine process the backlog +// before unlocking. +// +// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the +// endpoint. It's also required eventually when we get rid of the endpoint +// protocol goroutine altogether. +// +// Precondition: e.LockUser() must have been called before calling e.UnlockUser() +func (e *endpoint) UnlockUser() { + // Lock segment queue before checking so that we avoid a race where + // segments can be queued between the time we check if queue is empty + // and actually unlock the endpoint mutex. + for { + e.segmentQueue.mu.Lock() + if e.segmentQueue.emptyLocked() { + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + e.segmentQueue.mu.Unlock() + return + } + e.segmentQueue.mu.Unlock() + + switch e.EndpointState() { + case StateEstablished: + if err := e.handleSegments(true /* fastPath */); err != nil { + e.notifyProtocolGoroutine(notifyTickleWorker) + } + default: + // Since we are waking the endpoint goroutine here just unlock + // and let it process the queued segments. + e.newSegmentWaker.Assert() + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + return + } + } +} + // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { - e.workMu.Lock() + e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. func (e *endpoint) ResumeWork() { - e.workMu.Unlock() + e.mu.Unlock() } // setEndpointState updates the state of the endpoint to state atomically. This @@ -709,8 +809,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - e.workMu.Lock() e.tsOffset = timeStampOffset() return e @@ -721,9 +819,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) - e.mu.RLock() - defer e.mu.RUnlock() - switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -823,20 +918,22 @@ func (e *endpoint) Abort() { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { - e.mu.Lock() - closed := e.closed - e.closed = true - e.mu.Unlock() - if closed { + e.LockUser() + defer e.UnlockUser() + if e.closed { return } // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. - e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - - e.mu.Lock() + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail @@ -853,6 +950,8 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. switch e.EndpointState() { @@ -873,8 +972,6 @@ func (e *endpoint) Close() { // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) } - - e.mu.Unlock() } // closePendingAcceptableConnections closes all connections that have completed @@ -909,7 +1006,6 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { - // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { @@ -954,18 +1050,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { - e.mu.RLock() + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() - e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() - e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -1021,7 +1117,6 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() - e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1031,7 +1126,7 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() + e.LockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1041,7 +1136,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.mu.RUnlock() + e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1051,7 +1146,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1124,13 +1219,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. - e.mu.RLock() + e.LockUser() e.sndBufMu.Lock() avail, err := e.isEndpointWritableLocked() if err != nil { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -1142,113 +1237,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // are copying data in. if !opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } // Fetch data. v, perr := p.Payload(avail) if perr != nil || len(v) == 0 { - if opts.Atomic { // See above. + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } - // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - if opts.Atomic { + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - } - if e.workMu.TryLock() { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() - - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - } // Do the work inline. e.handleWrite() - e.workMu.Unlock() - } else { - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + e.UnlockUser() + return int64(len(v)), nil, nil + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } + if opts.Atomic { + // Locks released in queueAndSend() + return queueAndSend() + } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - } - // Let the protocol goroutine do the work. - e.sndWaker.Assert() + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return int64(len(v)), nil, nil + // Locks released in queueAndSend() + return queueAndSend() } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. @@ -1339,6 +1389,9 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -1346,9 +1399,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.mu.Lock() - defer e.mu.Unlock() - // We only allow this to be set when we're in the initial state. if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState @@ -1379,7 +1429,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) - e.mu.RLock() + e.LockUser() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1409,8 +1459,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } + e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.notifyProtocolGoroutine(mask) return nil @@ -1466,15 +1517,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.ReuseAddressOption: - e.mu.Lock() + e.LockUser() e.reuseAddr = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.ReusePortOption: - e.mu.Lock() + e.LockUser() e.reusePort = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BindToDeviceOption: @@ -1482,9 +1533,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } - e.mu.Lock() + e.LockUser() e.bindToDevice = id - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.QuickAckOption: @@ -1500,16 +1551,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { return tcpip.ErrInvalidOptionValue } - e.mu.Lock() + e.LockUser() e.userMSS = uint16(userMSS) - e.mu.Unlock() + e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) return nil case tcpip.TTLOption: - e.mu.Lock() + e.LockUser() e.ttl = uint8(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.KeepaliveEnabledOption: @@ -1541,15 +1592,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() e.userTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() e.broadcast = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.CongestionControlOption: @@ -1563,22 +1614,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { availCC := strings.Split(string(avail), " ") for _, cc := range availCC { if v == tcpip.CongestionControlOption(cc) { - // Acquire the work mutex as we may need to - // reinitialize the congestion control state. - e.mu.Lock() + e.LockUser() state := e.EndpointState() e.cc = v - e.mu.Unlock() switch state { case StateEstablished: - e.workMu.Lock() - e.mu.Lock() if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } - e.mu.Unlock() - e.workMu.Unlock() } + e.UnlockUser() return nil } } @@ -1588,23 +1633,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNoSuchFile case tcpip.IPv4TOSOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.IPv6TrafficClassOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() if v < 0 { // Same as effectively disabling TCPLinger timeout. v = 0 @@ -1622,16 +1667,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { v = stkTCPLingerTimeout } e.tcpLingerTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() if time.Duration(v) > MaxRTO { v = tcpip.TCPDeferAcceptOption(MaxRTO) } e.deferAccept = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1641,8 +1686,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // readyReceiveSize returns the number of bytes ready to be received. func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { @@ -1664,9 +1709,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.LockUser() v := e.v6only - e.mu.Unlock() + e.UnlockUser() return v, nil } @@ -1730,9 +1775,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReuseAddressOption: - e.mu.RLock() + e.LockUser() v := e.reuseAddr - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1741,9 +1786,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReusePortOption: - e.mu.RLock() + e.LockUser() v := e.reusePort - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1752,9 +1797,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BindToDeviceOption: - e.mu.RLock() + e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.QuickAckOption: @@ -1765,16 +1810,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TTLOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} - e.mu.RLock() + e.LockUser() snd := e.snd - e.mu.RUnlock() + e.UnlockUser() if snd != nil { snd.rtt.Lock() o.RTT = snd.rtt.srtt @@ -1813,9 +1858,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.OutOfBandInlineOption: @@ -1824,9 +1869,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() v := e.broadcast - e.mu.Unlock() + e.UnlockUser() *o = 0 if v { @@ -1835,33 +1880,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.CongestionControlOption: - e.mu.Lock() + e.LockUser() *o = e.cc - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.IPv4TOSOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1901,8 +1946,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // yet accepted by the app, they are restored without running the main goroutine // here. func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() connectingAddr := addr.Addr @@ -2071,9 +2116,13 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.Lock() + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} + +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags - finQueued := false switch { case e.EndpointState().connected(): // Close for read. @@ -2087,24 +2136,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.mu.Unlock() - // Try to send an active reset immediately if the - // work mutex is available. - if e.workMu.TryLock() { - e.mu.Lock() - // We need to double check here to make - // sure worker has not transitioned the - // endpoint out of a connected state - // before trying to send a reset. - if e.EndpointState().connected() { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) - } - e.mu.Unlock() - e.workMu.Unlock() - } else { - e.notifyProtocolGoroutine(notifyReset) - } + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) return nil } } @@ -2116,42 +2150,32 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - e.mu.Unlock() return tcpip.ErrNotConnected } - break + return nil } // Queue fin segment. s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() + e.handleClose() } + return nil case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } + return nil + default: - e.mu.Unlock() return tcpip.ErrNotConnected } - e.mu.Unlock() - if finQueued { - if e.workMu.TryLock() { - e.handleClose() - e.workMu.Unlock() - } else { - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() - } - } - return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept @@ -2166,8 +2190,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } func (e *endpoint) listen(backlog int) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -2229,7 +2253,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop() { - e.mu.Lock() e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2240,8 +2263,8 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // Endpoint must be in listen state before it can accept connections. if e.EndpointState() != StateListen { @@ -2260,8 +2283,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { // Bind binds the endpoint to a specific local port and optionally address. func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() return e.bindLocked(addr) } @@ -2339,8 +2362,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // GetLocalAddress returns the address to which the endpoint is bound. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() return tcpip.FullAddress{ Addr: e.ID.LocalAddress, @@ -2351,8 +2374,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the address to which the endpoint is connected. func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected @@ -2419,7 +2442,6 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() @@ -2434,7 +2456,6 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } @@ -2578,9 +2599,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SegTime = time.Now() // Copy EndpointID. - e.mu.Lock() s.ID = stack.TCPEndpointID(e.ID) - e.mu.Unlock() // Copy endpoint rcv state. e.rcvListMu.Lock() @@ -2710,10 +2729,10 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() + e.LockUser() // Make a copy of the endpoint info. ret := e.EndpointInfo - e.mu.RUnlock() + e.UnlockUser() return &ret } @@ -2728,9 +2747,9 @@ func (e *endpoint) Wait() { e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) defer e.waiterQueue.EventUnregister(&waitEntry) for { - e.mu.Lock() + e.LockUser() running := e.workerRunning - e.mu.Unlock() + e.UnlockUser() if !running { break } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4a46f0ec5..9175de441 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) { connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState - // as the endpoint is still being loaded and the stack reference to increment - // metrics is not yet initialized. + // as the endpoint is still being loaded and the stack reference is not + // yet initialized. atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } @@ -180,7 +180,6 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() state := e.origEndpointState switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 73098d904..b0f918bb4 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -95,7 +95,7 @@ const ( ) type protocol struct { - mu sync.Mutex + mu sync.RWMutex sackEnabled bool delayEnabled bool sendBufferSize SendBufferSizeOption @@ -273,57 +273,57 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: - p.mu.Lock() + p.mu.RLock() *v = SACKEnabled(p.sackEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *DelayEnabled: - p.mu.Lock() + p.mu.RLock() *v = DelayEnabled(p.delayEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *SendBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.sendBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *ReceiveBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.recvBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.CongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.CongestionControlOption(p.congestionControl) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.AvailableCongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.ModerateReceiveBufferOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPLingerTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index d80aff1b6..caf8977b3 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateEstablished: r.ep.setEndpointState(StateCloseWait) @@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateFinWait2: r.ep.setEndpointState(StateTimeWait) } - r.ep.mu.Unlock() // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the @@ -208,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -222,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateLastAck: r.ep.transitionToStateCloseLocked() } - r.ep.mu.Unlock() } return true @@ -336,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { - r.ep.mu.RLock() state := r.ep.EndpointState() closed := r.ep.closed - r.ep.mu.RUnlock() if state != StateEstablished { drop, err := r.handleRcvdSegmentClosing(s, state, closed) diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index bd20a7ee9..48a257137 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -28,10 +28,16 @@ type segmentQueue struct { used int } +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *segmentQueue) emptyLocked() bool { + return q.used == 0 +} + // empty determines if the queue is empty. func (q *segmentQueue) empty() bool { q.mu.Lock() - r := q.used == 0 + r := q.emptyLocked() q.mu.Unlock() return r diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 657c3146e..17fed4ec5 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -455,9 +455,7 @@ func (s *sender) retransmitTimerExpired() bool { // Give up if we've waited more than a minute since the last resend or // if a user time out is set and we have exceeded the user specified // timeout since the first retransmission. - s.ep.mu.RLock() uto := s.ep.userTimeout - s.ep.mu.RUnlock() if s.firstRetransmittedSegXmitTime.IsZero() { // We store the original xmitTime of the segment that we are @@ -713,7 +711,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.setEndpointState(StateFinWait1) } - } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5b2b16afa..39d36d2ba 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2236,9 +2236,17 @@ func TestSegmentMerging(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - // Prevent the endpoint from processing packets. - test.stop(c.EP) + // Send 10 1 byte segments to fill up InitialWindow but don't + // ACK. That should prevent anymore packets from going out. + for i := 0; i < 10; i++ { + view := buffer.NewViewFromBytes([]byte{0}) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + // Now send the segments that should get merged as the congestion + // window is full and we won't be able to send any more packets. var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) @@ -2248,8 +2256,29 @@ func TestSegmentMerging(t *testing.T) { } } - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) + // Check that we get 10 packets of 1 byte each. + for i := 0; i < 10; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+uint32(i)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. + RcvWnd: 30000, + }) // Check that data is received. b := c.GetPacket() @@ -2257,7 +2286,7 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+11), checker.AckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), @@ -2273,7 +2302,7 @@ func TestSegmentMerging(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) }) -- cgit v1.2.3 From 3a37f6791745a26d38b69fbe9d4090f8fd0c7827 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 19 Mar 2020 09:59:21 -0700 Subject: Change SocketOperations.readMu to an RWMutex. Also get rid of the readViewHasData as it's not required anymore. Updates #231, #357 PiperOrigin-RevId: 301837227 --- pkg/sentry/socket/netstack/netstack.go | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index a2e1da02f..a6ef7a47e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,7 +29,6 @@ import ( "io" "math" "reflect" - "sync/atomic" "syscall" "time" @@ -265,14 +264,8 @@ type SocketOperations struct { skType linux.SockType protocol int - // readViewHasData is 1 iff readView has data to be read, 0 otherwise. - // Must be accessed using atomic operations. It must only be written - // with readMu held but can be read without holding readMu. The latter - // is required to avoid deadlocks in epoll Readiness checks. - readViewHasData uint32 - // readMu protects access to the below fields. - readMu sync.Mutex `state:"nosave"` + readMu sync.RWMutex `state:"nosave"` // readView contains the remaining payload from the last packet. readView buffer.View // readCM holds control message information for the last packet read @@ -428,13 +421,11 @@ func (s *SocketOperations) fetchReadView() *syserr.Error { v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { - atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms - atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -633,9 +624,11 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - if atomic.LoadUint32(&s.readViewHasData) == 1 { + s.readMu.RLock() + if len(s.readView) > 0 { r |= waiter.EventIn } + s.readMu.RUnlock() } return r @@ -2342,9 +2335,6 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { @@ -2468,10 +2458,6 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } - var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC -- cgit v1.2.3 From 369cf38bd7186da97e134538cd4839a8a4d1aa2c Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Mon, 23 Mar 2020 16:05:29 -0700 Subject: Fix data race in SetSockOpt. PiperOrigin-RevId: 302539171 --- pkg/sentry/socket/netstack/netstack.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index a6ef7a47e..c19f5639b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2382,9 +2382,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // caller-supplied buffer. s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) - s.readMu.Unlock() cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) + s.readMu.Unlock() return n, 0, nil, 0, cmsg, err } -- cgit v1.2.3 From a730d74b3230fb32181b9a940c07b61338222874 Mon Sep 17 00:00:00 2001 From: Ian Lewis Date: Mon, 23 Mar 2020 16:11:37 -0700 Subject: Support basic /proc/net/dev metrics for netstack Fixes #506 PiperOrigin-RevId: 302540404 --- pkg/sentry/socket/netstack/stack.go | 73 ++++++++++++++++++++++++++----------- test/syscalls/linux/proc_net.cc | 53 +++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 21 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 0692482e9..a8e2e8c24 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -200,36 +200,66 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { switch stats := stat.(type) { + case *inet.StatDev: + for _, ni := range s.Stack.NICInfo() { + if ni.Name != arg { + continue + } + // TODO(gvisor.dev/issue/2103) Support stubbed stats. + *stats = inet.StatDev{ + // Receive section. + ni.Stats.Rx.Bytes.Value(), // bytes. + ni.Stats.Rx.Packets.Value(), // packets. + 0, // errs. + 0, // drop. + 0, // fifo. + 0, // frame. + 0, // compressed. + 0, // multicast. + // Transmit section. + ni.Stats.Tx.Bytes.Value(), // bytes. + ni.Stats.Tx.Packets.Value(), // packets. + 0, // errs. + 0, // drop. + 0, // fifo. + 0, // colls. + 0, // carrier. + 0, // compressed. + } + break + } case *inet.StatSNMPIP: ip := Metrics.IP + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPIP{ - 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. - 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + 0, // Ip/Forwarding. + 0, // Ip/DefaultTTL. ip.PacketsReceived.Value(), // InReceives. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + 0, // Ip/InHdrErrors. ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + 0, // Ip/ForwDatagrams. + 0, // Ip/InUnknownProtos. + 0, // Ip/InDiscards. ip.PacketsDelivered.Value(), // InDelivers. ip.PacketsSent.Value(), // OutRequests. ip.OutgoingPacketErrors.Value(), // OutDiscards. - 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + 0, // Ip/OutNoRoutes. + 0, // Support Ip/ReasmTimeout. + 0, // Support Ip/ReasmReqds. + 0, // Support Ip/ReasmOKs. + 0, // Support Ip/ReasmFails. + 0, // Support Ip/FragOKs. + 0, // Support Ip/FragFails. + 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ - 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + 0, // Icmp/InMsgs. Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. - 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. in.ParamProblem.Value(), // InParmProbs. @@ -241,7 +271,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.TimestampReply.Value(), // InTimestampReps. in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. - 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + 0, // Icmp/OutMsgs. Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. out.DstUnreachable.Value(), // OutDestUnreachs. out.TimeExceeded.Value(), // OutTimeExcds. @@ -277,15 +307,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } case *inet.StatSNMPUDP: udp := Metrics.UDP + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPUDP{ udp.PacketsReceived.Value(), // InDatagrams. udp.UnknownPortErrors.Value(), // NoPorts. - 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + 0, // Udp/InErrors. udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + 0, // Udp/SndbufErrors. + 0, // Udp/InCsumErrors. + 0, // Udp/IgnoredMulti. } default: return syserr.ErrEndpointOperation.ToError() diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 05c952b99..4e23d1e78 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -92,6 +92,59 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +// DeviceEntry is an entry in /proc/net/dev +struct DeviceEntry { + std::string name; + uint64_t stats[16]; +}; + +PosixErrorOr> GetDeviceMetricsFromProc( + const std::string dev) { + std::vector lines = absl::StrSplit(dev, '\n'); + std::vector entries; + + // /proc/net/dev prints 2 lines of headers followed by a line of metrics for + // each network interface. + for (unsigned i = 2; i < lines.size(); i++) { + // Ignore empty lines. + if (lines[i].empty()) { + continue; + } + + std::vector values = + absl::StrSplit(lines[i], ' ', absl::SkipWhitespace()); + + // Interface name + 16 values. + if (values.size() != 17) { + return PosixError(EINVAL, "invalid line: " + lines[i]); + } + + DeviceEntry entry; + entry.name = values[0]; + // Skip the interface name and read only the values. + for (unsigned j = 1; j < 17; j++) { + uint64_t num; + if (!absl::SimpleAtoi(values[j], &num)) { + return PosixError(EINVAL, "invalid value: " + values[j]); + } + entry.stats[j - 1] = num; + } + + entries.push_back(entry); + } + + return entries; +} + +// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and +// contains at least one entry. +TEST(ProcNetDev, Format) { + auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev")); + auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev)); + + EXPECT_GT(entries.size(), 0); +} + PosixErrorOr GetSNMPMetricFromProc(const std::string snmp, const std::string& type, const std::string& item) { -- cgit v1.2.3 From 7e4073af12bed2c76bc5757ef3e5fbfba75308a0 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 24 Mar 2020 09:05:06 -0700 Subject: Move tcpip.PacketBuffer and IPTables to stack package. This is a precursor to be being able to build an intrusive list of PacketBuffers for use in queuing disciplines being implemented. Updates #2214 PiperOrigin-RevId: 302677662 --- pkg/sentry/socket/netfilter/BUILD | 1 - pkg/sentry/socket/netfilter/extensions.go | 14 +- pkg/sentry/socket/netfilter/netfilter.go | 121 ++++---- pkg/sentry/socket/netfilter/targets.go | 11 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 11 +- pkg/sentry/socket/netfilter/udp_matcher.go | 13 +- pkg/sentry/socket/netstack/BUILD | 1 - pkg/sentry/socket/netstack/stack.go | 3 +- pkg/tcpip/BUILD | 2 - pkg/tcpip/iptables/BUILD | 18 -- pkg/tcpip/iptables/iptables.go | 314 --------------------- pkg/tcpip/iptables/targets.go | 144 ---------- pkg/tcpip/iptables/types.go | 180 ------------ pkg/tcpip/link/channel/channel.go | 14 +- pkg/tcpip/link/fdbased/endpoint.go | 6 +- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +- pkg/tcpip/link/fdbased/mmap.go | 3 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 +- pkg/tcpip/link/loopback/loopback.go | 8 +- pkg/tcpip/link/muxed/injectable.go | 6 +- pkg/tcpip/link/muxed/injectable_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 26 +- pkg/tcpip/link/sniffer/sniffer.go | 10 +- pkg/tcpip/link/tun/device.go | 2 +- pkg/tcpip/link/waitable/waitable.go | 6 +- pkg/tcpip/link/waitable/waitable_test.go | 18 +- pkg/tcpip/network/arp/arp.go | 12 +- pkg/tcpip/network/arp/arp_test.go | 2 +- pkg/tcpip/network/ip_test.go | 24 +- pkg/tcpip/network/ipv4/BUILD | 1 - pkg/tcpip/network/ipv4/icmp.go | 9 +- pkg/tcpip/network/ipv4/ipv4.go | 21 +- pkg/tcpip/network/ipv4/ipv4_test.go | 18 +- pkg/tcpip/network/ipv6/icmp.go | 10 +- pkg/tcpip/network/ipv6/icmp_test.go | 14 +- pkg/tcpip/network/ipv6/ipv6.go | 10 +- pkg/tcpip/network/ipv6/ipv6_test.go | 4 +- pkg/tcpip/network/ipv6/ndp_test.go | 8 +- pkg/tcpip/packet_buffer.go | 67 ----- pkg/tcpip/packet_buffer_state.go | 27 -- pkg/tcpip/stack/BUILD | 8 +- pkg/tcpip/stack/forwarder.go | 4 +- pkg/tcpip/stack/forwarder_test.go | 36 +-- pkg/tcpip/stack/iptables.go | 311 ++++++++++++++++++++ pkg/tcpip/stack/iptables_targets.go | 144 ++++++++++ pkg/tcpip/stack/iptables_types.go | 180 ++++++++++++ pkg/tcpip/stack/ndp.go | 4 +- pkg/tcpip/stack/ndp_test.go | 14 +- pkg/tcpip/stack/nic.go | 13 +- pkg/tcpip/stack/nic_test.go | 3 +- pkg/tcpip/stack/packet_buffer.go | 66 +++++ pkg/tcpip/stack/packet_buffer_state.go | 26 ++ pkg/tcpip/stack/registration.go | 32 +-- pkg/tcpip/stack/route.go | 6 +- pkg/tcpip/stack/stack.go | 11 +- pkg/tcpip/stack/stack_test.go | 28 +- pkg/tcpip/stack/transport_demuxer.go | 14 +- pkg/tcpip/stack/transport_demuxer_test.go | 2 +- pkg/tcpip/stack/transport_test.go | 27 +- pkg/tcpip/transport/icmp/BUILD | 1 - pkg/tcpip/transport/icmp/endpoint.go | 11 +- pkg/tcpip/transport/icmp/protocol.go | 2 +- pkg/tcpip/transport/packet/BUILD | 1 - pkg/tcpip/transport/packet/endpoint.go | 9 +- pkg/tcpip/transport/raw/BUILD | 1 - pkg/tcpip/transport/raw/endpoint.go | 9 +- pkg/tcpip/transport/tcp/BUILD | 1 - pkg/tcpip/transport/tcp/connect.go | 6 +- pkg/tcpip/transport/tcp/dispatcher.go | 3 +- pkg/tcpip/transport/tcp/endpoint.go | 7 +- pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 4 +- pkg/tcpip/transport/tcp/segment.go | 3 +- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +- pkg/tcpip/transport/udp/BUILD | 1 - pkg/tcpip/transport/udp/endpoint.go | 9 +- pkg/tcpip/transport/udp/forwarder.go | 4 +- pkg/tcpip/transport/udp/protocol.go | 6 +- pkg/tcpip/transport/udp/udp_test.go | 4 +- 80 files changed, 1080 insertions(+), 1126 deletions(-) delete mode 100644 pkg/tcpip/iptables/BUILD delete mode 100644 pkg/tcpip/iptables/iptables.go delete mode 100644 pkg/tcpip/iptables/targets.go delete mode 100644 pkg/tcpip/iptables/types.go delete mode 100644 pkg/tcpip/packet_buffer.go delete mode 100644 pkg/tcpip/packet_buffer_state.go create mode 100644 pkg/tcpip/stack/iptables.go create mode 100644 pkg/tcpip/stack/iptables_targets.go create mode 100644 pkg/tcpip/stack/iptables_types.go create mode 100644 pkg/tcpip/stack/packet_buffer.go create mode 100644 pkg/tcpip/stack/packet_buffer_state.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 7cd2ce55b..e801abeb8 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/usermem", ], diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index b4b244abf..0336a32d8 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,12 +37,12 @@ type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. name() string - // marshal converts from an iptables.Matcher to an ABI struct. - marshal(matcher iptables.Matcher) []byte + // marshal converts from an stack.Matcher to an ABI struct. + marshal(matcher stack.Matcher) []byte // unmarshal converts from the ABI matcher struct to an - // iptables.Matcher. - unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) + // stack.Matcher. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) } // matchMakers maps the name of supported matchers to the matchMaker that @@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) { matchMakers[mm.name()] = mm } -func marshalMatcher(matcher iptables.Matcher) []byte { +func marshalMatcher(matcher stack.Matcher) []byte { matchMaker, ok := matchMakers[matcher.Name()] if !ok { panic(fmt.Sprintf("Unknown matcher of type %T.", matcher)) @@ -86,7 +86,7 @@ func marshalEntryMatch(name string, data []byte) []byte { return append(buf, make([]byte, size-len(buf))...) } -func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) { +func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { matchMaker, ok := matchMakers[match.Name.String()] if !ok { return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index b5b9be46f..55bcc3ace 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -129,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) { - ipt := stack.IPTables() +func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { + ipt := stk.IPTables() table, ok := ipt.Tables[tablename.String()] if !ok { - return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename) + return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } // FillDefaultIPTables sets stack's IPTables to the default tables and // populates them with metadata. -func FillDefaultIPTables(stack *stack.Stack) { - ipt := iptables.DefaultTables() +func FillDefaultIPTables(stk *stack.Stack) { + ipt := stack.DefaultTables() // In order to fill in the metadata, we have to translate ipt from its // netstack format to Linux's giant-binary-blob format. @@ -154,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) { ipt.Tables[name] = table } - stack.SetIPTables(ipt) + stk.SetIPTables(ipt) } // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) { +func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { // Return values. var entries linux.KernelIPTGetEntries var meta metadata @@ -234,19 +233,19 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern return entries, meta, nil } -func marshalTarget(target iptables.Target) []byte { +func marshalTarget(target stack.Target) []byte { switch tg := target.(type) { - case iptables.AcceptTarget: - return marshalStandardTarget(iptables.RuleAccept) - case iptables.DropTarget: - return marshalStandardTarget(iptables.RuleDrop) - case iptables.ErrorTarget: + case stack.AcceptTarget: + return marshalStandardTarget(stack.RuleAccept) + case stack.DropTarget: + return marshalStandardTarget(stack.RuleDrop) + case stack.ErrorTarget: return marshalErrorTarget(errorTargetName) - case iptables.UserChainTarget: + case stack.UserChainTarget: return marshalErrorTarget(tg.Name) - case iptables.ReturnTarget: - return marshalStandardTarget(iptables.RuleReturn) - case iptables.RedirectTarget: + case stack.ReturnTarget: + return marshalStandardTarget(stack.RuleReturn) + case stack.RedirectTarget: return marshalRedirectTarget() case JumpTarget: return marshalJumpTarget(tg) @@ -255,7 +254,7 @@ func marshalTarget(target iptables.Target) []byte { } } -func marshalStandardTarget(verdict iptables.RuleVerdict) []byte { +func marshalStandardTarget(verdict stack.RuleVerdict) []byte { nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. @@ -316,13 +315,13 @@ func marshalJumpTarget(jt JumpTarget) []byte { // translateFromStandardVerdict translates verdicts the same way as the iptables // tool. -func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { +func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { switch verdict { - case iptables.RuleAccept: + case stack.RuleAccept: return -linux.NF_ACCEPT - 1 - case iptables.RuleDrop: + case stack.RuleDrop: return -linux.NF_DROP - 1 - case iptables.RuleReturn: + case stack.RuleReturn: return linux.NF_RETURN default: // TODO(gvisor.dev/issue/170): Support Jump. @@ -331,18 +330,18 @@ func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { } // translateToStandardTarget translates from the value in a -// linux.XTStandardTarget to an iptables.Verdict. -func translateToStandardTarget(val int32) (iptables.Target, error) { +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32) (stack.Target, error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return iptables.AcceptTarget{}, nil + return stack.AcceptTarget{}, nil case -linux.NF_DROP - 1: - return iptables.DropTarget{}, nil + return stack.DropTarget{}, nil case -linux.NF_QUEUE - 1: return nil, errors.New("unsupported iptables verdict QUEUE") case linux.NF_RETURN: - return iptables.ReturnTarget{}, nil + return stack.ReturnTarget{}, nil default: return nil, fmt.Errorf("unknown iptables verdict %d", val) } @@ -350,7 +349,7 @@ func translateToStandardTarget(val int32) (iptables.Target, error) { // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { +func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { nflog("optVal has insufficient size for replace %d", len(optVal)) @@ -362,12 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) // TODO(gvisor.dev/issue/170): Support other tables. - var table iptables.Table + var table stack.Table switch replace.Name.String() { - case iptables.TablenameFilter: - table = iptables.EmptyFilterTable() - case iptables.TablenameNat: - table = iptables.EmptyNatTable() + case stack.TablenameFilter: + table = stack.EmptyFilterTable() + case stack.TablenameNat: + table = stack.EmptyNatTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -434,7 +433,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } optVal = optVal[targetSize:] - table.Rules = append(table.Rules, iptables.Rule{ + table.Rules = append(table.Rules, stack.Rule{ Filter: filter, Target: target, Matchers: matchers, @@ -465,11 +464,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { table.Underflows[hk] = ruleIdx } } - if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset { + if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset { nflog("hook %v is unset.", hk) return syserr.ErrInvalidArgument } - if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset { + if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset { nflog("underflow %v is unset.", hk) return syserr.ErrInvalidArgument } @@ -478,7 +477,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(iptables.UserChainTarget) + target, ok := rule.Target.(stack.UserChainTarget) if !ok { continue } @@ -522,8 +521,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // PREROUTING chain right now, make sure all other chains point to // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != iptables.Input && hook != iptables.Prerouting { - if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok { + if hook != stack.Input && hook != stack.Prerouting { + if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -535,7 +534,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stack.IPTables() + ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, @@ -543,16 +542,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { Size: replace.Size, }) ipt.Tables[replace.Name.String()] = table - stack.SetIPTables(ipt) + stk.SetIPTables(ipt) return nil } // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. -func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) { +func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { nflog("set entries: parsing matchers of size %d", len(optVal)) - var matchers []iptables.Matcher + var matchers []stack.Matcher for len(optVal) > 0 { nflog("set entries: optVal has len %d", len(optVal)) @@ -594,7 +593,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target, error) { +func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) @@ -638,11 +637,11 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target switch name := errorTarget.Name.String(); name { case errorTargetName: nflog("set entries: error target") - return iptables.ErrorTarget{}, nil + return stack.ErrorTarget{}, nil default: // User defined chain. nflog("set entries: user-defined target %q", name) - return iptables.UserChainTarget{Name: name}, nil + return stack.UserChainTarget{Name: name}, nil } case redirectTargetName: @@ -659,8 +658,8 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target buf = optVal[:linux.SizeOfXTRedirectTarget] binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - // Copy linux.XTRedirectTarget to iptables.RedirectTarget. - var target iptables.RedirectTarget + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + var target stack.RedirectTarget nfRange := redirectTarget.NfRange // RangeSize should be 1. @@ -699,14 +698,14 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) } -func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) { +func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { - return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) } if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { - return iptables.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } - return iptables.IPHeaderFilter{ + return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), @@ -733,30 +732,30 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { iptip.InverseFlags&^inverseMask != 0 } -func validUnderflow(rule iptables.Rule) bool { +func validUnderflow(rule stack.Rule) bool { if len(rule.Matchers) != 0 { return false } switch rule.Target.(type) { - case iptables.AcceptTarget, iptables.DropTarget: + case stack.AcceptTarget, stack.DropTarget: return true default: return false } } -func hookFromLinux(hook int) iptables.Hook { +func hookFromLinux(hook int) stack.Hook { switch hook { case linux.NF_INET_PRE_ROUTING: - return iptables.Prerouting + return stack.Prerouting case linux.NF_INET_LOCAL_IN: - return iptables.Input + return stack.Input case linux.NF_INET_FORWARD: - return iptables.Forward + return stack.Forward case linux.NF_INET_LOCAL_OUT: - return iptables.Output + return stack.Output case linux.NF_INET_POST_ROUTING: - return iptables.Postrouting + return stack.Postrouting } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index c421b87cf..c948de876 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,10 @@ package netfilter import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// JumpTarget implements iptables.Target. +// JumpTarget implements stack.Target. type JumpTarget struct { // Offset is the byte offset of the rule to jump to. It is used for // marshaling and unmarshaling. @@ -29,7 +28,7 @@ type JumpTarget struct { RuleNum int } -// Action implements iptables.Target.Action. -func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) { - return iptables.RuleJump, jt.RuleNum +// Action implements stack.Target.Action. +func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) { + return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index f9945e214..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -19,9 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,7 +39,7 @@ func (tcpMarshaler) name() string { } // marshal implements matchMaker.marshal. -func (tcpMarshaler) marshal(mr iptables.Matcher) []byte { +func (tcpMarshaler) marshal(mr stack.Matcher) []byte { matcher := mr.(*TCPMatcher) xttcp := linux.XTTCP{ SourcePortStart: matcher.sourcePortStart, @@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { +func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTTCP { return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) } @@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { @@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac // Now we need the transport header. However, this may not have been set // yet. // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the iptables.Check codepath as matchers are + // ultimately be moved into the stack.Check codepath as matchers are // added. var tcpHeader header.TCP if pkt.TransportHeader != nil { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 86aa11696..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -19,9 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,7 +39,7 @@ func (udpMarshaler) name() string { } // marshal implements matchMaker.marshal. -func (udpMarshaler) marshal(mr iptables.Matcher) []byte { +func (udpMarshaler) marshal(mr stack.Matcher) []byte { matcher := mr.(*UDPMatcher) xtudp := linux.XTUDP{ SourcePortStart: matcher.sourcePortStart, @@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { +func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTUDP { return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) } @@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the iptables.Check codepath as matchers are added. + // into the stack.Check codepath as matchers are added. if netHeader.TransportProtocol() != header.UDPProtocolNumber { return false, false } @@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac // Now we need the transport header. However, this may not have been set // yet. // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the iptables.Check codepath as matchers are + // ultimately be moved into the stack.Check codepath as matchers are // added. var udpHeader header.UDP if pkt.TransportHeader != nil { diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ab01cb4fa..cbf46b1e9 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -38,7 +38,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index a8e2e8c24..f5fa18136 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -363,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (iptables.IPTables, error) { +func (s *Stack) IPTables() (stack.IPTables, error) { return s.Stack.IPTables(), nil } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 26f7ba86b..454e07662 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -5,8 +5,6 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ - "packet_buffer.go", - "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", "timer.go", diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD deleted file mode 100644 index d1b73cfdf..000000000 --- a/pkg/tcpip/iptables/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - srcs = [ - "iptables.go", - "targets.go", - "types.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go deleted file mode 100644 index d30571c74..000000000 --- a/pkg/tcpip/iptables/iptables.go +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package iptables supports packet filtering and manipulation via the iptables -// tool. -package iptables - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// Table names. -const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" -) - -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. -const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" -) - -// HookUnset indicates that there is no hook set for an entrypoint or -// underflow. -const HookUnset = -1 - -// DefaultTables returns a default set of tables. Each chain is set to accept -// all packets. -func DefaultTables() IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. - return IPTables{ - Tables: map[string]Table{ - TablenameNat: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Prerouting: 0, - Input: 1, - Output: 2, - Postrouting: 3, - }, - Underflows: map[Hook]int{ - Prerouting: 0, - Input: 1, - Output: 2, - Postrouting: 3, - }, - UserChains: map[string]int{}, - }, - TablenameMangle: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Prerouting: 0, - Output: 1, - }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, - }, - UserChains: map[string]int{}, - }, - TablenameFilter: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, - }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, - }, - UserChains: map[string]int{}, - }, - }, - Priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, - }, - } -} - -// EmptyFilterTable returns a Table with no rules and the filter table chains -// mapped to HookUnset. -func EmptyFilterTable() Table { - return Table{ - Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, - }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, - }, - UserChains: map[string]int{}, - } -} - -// EmptyNatTable returns a Table with no rules and the filter table chains -// mapped to HookUnset. -func EmptyNatTable() Table { - return Table{ - Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, - }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, - }, - UserChains: map[string]int{}, - } -} - -// A chainVerdict is what a table decides should be done with a packet. -type chainVerdict int - -const ( - // chainAccept indicates the packet should continue through netstack. - chainAccept chainVerdict = iota - - // chainAccept indicates the packet should be dropped. - chainDrop - - // chainReturn indicates the packet should return to the calling chain - // or the underflow rule of a builtin chain. - chainReturn -) - -// Check runs pkt through the rules for hook. It returns true when the packet -// should continue traversing the network stack and false when it should be -// dropped. -// -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { - // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] - ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { - // If the table returns Accept, move on to the next table. - case chainAccept: - continue - // The Drop verdict is final. - case chainDrop: - return false - case chainReturn: - // Any Return from a built-in chain means we have to - // call the underflow. - underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt); v { - case RuleAccept: - continue - case RuleDrop: - return false - case RuleJump, RuleReturn: - panic("Underflows should only return RuleAccept or RuleDrop.") - default: - panic(fmt.Sprintf("Unknown verdict: %d", v)) - } - - default: - panic(fmt.Sprintf("Unknown verdict %v.", verdict)) - } - } - - // Every table returned Accept. - return true -} - -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict { - // Start from ruleIdx and walk the list of rules until a rule gives us - // a verdict. - for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { - case RuleAccept: - return chainAccept - - case RuleDrop: - return chainDrop - - case RuleReturn: - return chainReturn - - case RuleJump: - // "Jumping" to the next rule just means we're - // continuing on down the list. - if jumpTo == ruleIdx+1 { - ruleIdx++ - continue - } - switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { - case chainAccept: - return chainAccept - case chainDrop: - return chainDrop - case chainReturn: - ruleIdx++ - continue - default: - panic(fmt.Sprintf("Unknown verdict: %d", verdict)) - } - - default: - panic(fmt.Sprintf("Unknown verdict: %d", verdict)) - } - - } - - // We got through the entire table without a decision. Default to DROP - // for safety. - return chainDrop -} - -// Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { - rule := table.Rules[ruleIdx] - - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). - if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() - } - - // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { - // Continue on to the next rule. - return RuleJump, ruleIdx + 1 - } - - // Go through each rule matcher. If they all match, run - // the rule target. - for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") - if hotdrop { - return RuleDrop, 0 - } - if !matches { - // Continue on to the next rule. - return RuleJump, ruleIdx + 1 - } - } - - // All the matchers matched, so run the target. - return rule.Target.Action(pkt) -} - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - return true -} diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go deleted file mode 100644 index e457f2349..000000000 --- a/pkg/tcpip/iptables/targets.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// AcceptTarget accepts packets. -type AcceptTarget struct{} - -// Action implements Target.Action. -func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleAccept, 0 -} - -// DropTarget drops packets. -type DropTarget struct{} - -// Action implements Target.Action. -func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleDrop, 0 -} - -// ErrorTarget logs an error and drops the packet. It represents a target that -// should be unreachable. -type ErrorTarget struct{} - -// Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - log.Debugf("ErrorTarget triggered.") - return RuleDrop, 0 -} - -// UserChainTarget marks a rule as the beginning of a user chain. -type UserChainTarget struct { - Name string -} - -// Action implements Target.Action. -func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { - panic("UserChainTarget should never be called.") -} - -// ReturnTarget returns from the current chain. If the chain is a built-in, the -// hook's underflow should be called. -type ReturnTarget struct{} - -// Action implements Target.Action. -func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleReturn, 0 -} - -// RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. -type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool - - // Min address used to redirect. - MinIP tcpip.Address - - // Max address used to redirect. - MaxIP tcpip.Address - - // Min port used to redirect. - MinPort uint16 - - // Max port used to redirect. - MaxPort uint16 -} - -// Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither -// of which should be the case. -func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer) (RuleVerdict, int) { - newPkt := pkt.Clone() - - // Set network header. - headerView := newPkt.Data.First() - netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] - - hlen := int(netHeader.HeaderLength()) - tlen := int(netHeader.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) - - // TODO(gvisor.dev/issue/170): Change destination address to - // loopback or interface address on which the packet was - // received. - - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - var udpHeader header.UDP - if newPkt.TransportHeader != nil { - udpHeader = header.UDP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { - return RuleDrop, 0 - } - udpHeader = header.UDP(newPkt.Data.First()) - } - udpHeader.SetDestinationPort(rt.MinPort) - case header.TCPProtocolNumber: - var tcpHeader header.TCP - if newPkt.TransportHeader != nil { - tcpHeader = header.TCP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { - return RuleDrop, 0 - } - tcpHeader = header.TCP(newPkt.TransportHeader) - } - // TODO(gvisor.dev/issue/170): Need to recompute checksum - // and implement nat connection tracking to support TCP. - tcpHeader.SetDestinationPort(rt.MinPort) - default: - return RuleDrop, 0 - } - - return RuleAccept, 0 -} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go deleted file mode 100644 index e7fcf6bff..000000000 --- a/pkg/tcpip/iptables/types.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "gvisor.dev/gvisor/pkg/tcpip" -) - -// A Hook specifies one of the hooks built into the network stack. -// -// Userspace app Userspace app -// ^ | -// | v -// [Input] [Output] -// ^ | -// | v -// | routing -// | | -// | v -// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> -type Hook uint - -// These values correspond to values in include/uapi/linux/netfilter.h. -const ( - // Prerouting happens before a packet is routed to applications or to - // be forwarded. - Prerouting Hook = iota - - // Input happens before a packet reaches an application. - Input - - // Forward happens once it's decided that a packet should be forwarded - // to another host. - Forward - - // Output happens after a packet is written by an application to be - // sent out. - Output - - // Postrouting happens just before a packet goes out on the wire. - Postrouting - - // The total number of hooks. - NumHooks -) - -// A RuleVerdict is what a rule decides should be done with a packet. -type RuleVerdict int - -const ( - // RuleAccept indicates the packet should continue through netstack. - RuleAccept RuleVerdict = iota - - // RuleDrop indicates the packet should be dropped. - RuleDrop - - // RuleJump indicates the packet should jump to another chain. - RuleJump - - // RuleReturn indicates the packet should return to the previous chain. - RuleReturn -) - -// IPTables holds all the tables for a netstack. -type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table - - // Priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string -} - -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. -type Table struct { - // Rules holds the rules that make up the table. - Rules []Rule - - // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int - - // Underflows maps builtin chains to their underflow rule in Rules - // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} -} - -// ValidHooks returns a bitmap of the builtin hooks for the given table. -func (table *Table) ValidHooks() uint32 { - hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook - } - return hooks -} - -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - -// A Rule is a packet processing rule. It consists of two pieces. First it -// contains zero or more matchers, each of which is a specification of which -// packets this rule applies to. If there are no matchers in the rule, it -// applies to any packet. -type Rule struct { - // Filter holds basic IP filtering fields common to every rule. - Filter IPHeaderFilter - - // Matchers is the list of matchers for this rule. - Matchers []Matcher - - // Target is the action to invoke if all the matchers match the packet. - Target Target -} - -// IPHeaderFilter holds basic IP filtering data common to every rule. -type IPHeaderFilter struct { - // Protocol matches the transport protocol. - Protocol tcpip.TransportProtocolNumber - - // Dst matches the destination IP address. - Dst tcpip.Address - - // DstMask masks bits of the destination IP address when comparing with - // Dst. - DstMask tcpip.Address - - // DstInvert inverts the meaning of the destination IP check, i.e. when - // true the filter will match packets that fail the destination - // comparison. - DstInvert bool -} - -// A Matcher is the interface for matching packets. -type Matcher interface { - // Name returns the name of the Matcher. - Name() string - - // Match returns whether the packet matches and whether the packet - // should be "hotdropped", i.e. dropped immediately. This is usually - // used for suspicious packets. - // - // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) -} - -// A Target is the interface for taking an action for a packet. -type Target interface { - // Action takes an action on the packet and returns a verdict on how - // traversal should (or should not) continue. If the return value is - // Jump, it also returns the index of the rule to jump to. - Action(packet tcpip.PacketBuffer) (RuleVerdict, int) -} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5944ba190..a8d6653ce 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt tcpip.PacketBuffer + Pkt stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -203,12 +203,12 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } @@ -251,7 +251,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() @@ -269,7 +269,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() @@ -280,7 +280,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac off := pkt.DataOffset size := pkt.DataSize p := PacketInfo{ - Pkt: tcpip.PacketBuffer{ + Pkt: stack.PacketBuffer{ Header: pkt.Header, Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), }, @@ -301,7 +301,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: tcpip.PacketBuffer{Data: vv}, + Pkt: stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b36b9673..235e647ff 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,7 +386,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -440,7 +440,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { var ethHdrBuf []byte // hdr + data iovLen := 2 @@ -610,7 +610,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 2066987eb..c7dbbbc6b 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents tcpip.PacketBuffer + contents stack.PacketBuffer } type context struct { @@ -92,7 +92,7 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -168,7 +168,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -261,7 +261,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -324,7 +324,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: tcpip.PacketBuffer{ + contents: stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 62ed1e569..fe2bf3b0b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -190,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index c67d684ce..cb4cbea69 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } @@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 499cc608f..4039753b7 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // There should be an ethernet header at the beginning of vv. linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 445b22c17..f5973066d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,14 +80,14 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts [ // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 63b249837..87c734c1f 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 655e537c4..6461d0108 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 5c729a439..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3392b7edd..0a6b8945c 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -123,7 +123,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("recv", protocol, pkt.Data.First(), nil) } @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -232,7 +232,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { e.dumpPacket(gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } @@ -240,10 +240,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { view := pkts[0].Data.ToView() for _, pkt := range pkts { - e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + e.dumpPacket(gso, protocol, stack.PacketBuffer{ Header: pkt.Header, Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), }) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f6e301304..617446ea2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a8de38979..52fe397bf 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -112,7 +112,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { return len(pkts), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 31b11a27a..88224e494 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dispatchCount++ } @@ -65,13 +65,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { e.writeCount += len(pkts) return len(pkts), nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index e9fcc89a8..255098372 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,20 +79,20 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { @@ -113,7 +113,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -167,7 +167,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 03cf03b6d..b3e239ac7 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f4d78f8c6..4950d69fc 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -246,7 +246,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,7 +289,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -379,7 +379,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: vv, }) if want := c.expectedCount; o.controlCalls != want { @@ -444,7 +444,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: frag1.ToVectorisedView(), }) if o.dataCalls != 0 { @@ -452,7 +452,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send second segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: frag2.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -487,7 +487,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,7 +530,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -644,7 +644,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view[:len(view)-c.trunc].ToVectorisedView(), }) if want := c.expectedCount; o.controlCalls != want { diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 0fef2b1f1..880ea7de2 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 32bf39e43..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,7 +15,6 @@ package ipv4 import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -25,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP @@ -53,7 +52,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived v := pkt.Data.First() @@ -85,7 +84,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -99,7 +98,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4f1742938..b3ee6000e 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -125,7 +124,7 @@ func (e *endpoint) GSOMaxSize() uint32 { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -165,7 +164,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -184,7 +183,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -198,7 +197,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -241,7 +240,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,7 +252,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + e.HandlePacket(&loopedR, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -273,7 +272,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -292,7 +291,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -344,7 +343,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { @@ -361,7 +360,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(iptables.Input, pkt); !ok { + if ok := ipt.Check(stack.Input, pkt); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index e900f1b45..5a864d832 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -113,7 +113,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan tcpip.PacketBuffer + Ch chan stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,7 +183,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan tcpip.PacketBuffer, size), + Ch: make(chan stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -202,7 +202,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -281,13 +281,13 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := tcpip.PacketBuffer{ + source := stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -295,7 +295,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []tcpip.PacketBuffer + var results []stack.PacketBuffer L: for { select { @@ -337,7 +337,7 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -459,7 +459,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 45dc757c7..8640feffc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -243,7 +243,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -330,7 +330,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -463,7 +463,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 50c4b6474..bae09ed94 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -56,7 +56,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { return nil } @@ -66,7 +66,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -187,7 +187,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -326,7 +326,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ Data: vv, }) } @@ -561,7 +561,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -738,7 +738,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -916,7 +916,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9aef5234b..29e597002 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,7 +112,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -124,7 +124,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + e.HandlePacket(&loopedR, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -139,7 +139,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } @@ -161,14 +161,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1cbfa7278..ed98ef22a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -113,7 +113,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index c9395de52..f924ed9e1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -135,7 +135,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -238,7 +238,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -304,7 +304,7 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, tcpip.PacketBuffer{ + ep.HandlePacket(r, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -588,7 +588,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go deleted file mode 100644 index ab24372e7..000000000 --- a/pkg/tcpip/packet_buffer.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// A PacketBuffer contains all the data of a network packet. -// -// As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. -// -// +stateify savable -type PacketBuffer struct { - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. - // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. - Data buffer.VectorisedView - - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. - Header buffer.Prependable - - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. - // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). - // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View -} - -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk -} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go deleted file mode 100644 index ad3cc24fa..000000000 --- a/pkg/tcpip/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6c029b2fb..7a43a1d4e 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -21,10 +21,15 @@ go_library( "dhcpv6configurationfromndpra_string.go", "forwarder.go", "icmp_rate_limit.go", + "iptables.go", + "iptables_targets.go", + "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", "nic.go", + "packet_buffer.go", + "packet_buffer_state.go", "registration.go", "route.go", "stack.go", @@ -34,6 +39,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/ilist", + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", @@ -41,7 +47,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", @@ -65,7 +70,6 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 631953935..6b64cd37f 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt tcpip.PacketBuffer + pkt PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 321b7524d..c45c43d21 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) @@ -89,7 +89,7 @@ func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) @@ -101,11 +101,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -183,7 +183,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt tcpip.PacketBuffer + Pkt PacketBuffer } type fwdTestLinkEndpoint struct { @@ -196,12 +196,12 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } @@ -244,7 +244,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -260,7 +260,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for _, pkt := range pkts { e.WritePacket(r, gso, protocol, pkt) @@ -273,7 +273,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.Pack // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: tcpip.PacketBuffer{Data: vv}, + Pkt: PacketBuffer{Data: vv}, } select { @@ -355,7 +355,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -392,7 +392,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -423,7 +423,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -453,7 +453,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -461,7 +461,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -503,7 +503,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -550,7 +550,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[0] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -603,7 +603,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go new file mode 100644 index 000000000..37907ae24 --- /dev/null +++ b/pkg/tcpip/stack/iptables.go @@ -0,0 +1,311 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Table names. +const ( + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" +) + +// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +const ( + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" +) + +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() IPTables { + // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for + // iotas. + return IPTables{ + Tables: map[string]Table{ + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + UserChains: map[string]int{}, + }, + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, + }, + }, + Priorities: map[Hook][]string{ + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + }, + } +} + +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + +// Check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { + // Go through each table containing the hook. + for _, tablename := range it.Priorities[hook] { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + // If the table returns Accept, move on to the next table. + case chainAccept: + continue + // The Drop verdict is final. + case chainDrop: + return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + + default: + panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + } + } + + // Every table returned Accept. + return true +} + +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { + // Start from ruleIdx and walk the list of rules until a rule gives us + // a verdict. + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + case RuleAccept: + return chainAccept + + case RuleDrop: + return chainDrop + + case RuleReturn: + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + } + + // We got through the entire table without a decision. Default to DROP + // for safety. + return chainDrop +} + +// Precondition: pk.NetworkHeader is set. +func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { + rule := table.Rules[ruleIdx] + + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data.First(). + if pkt.NetworkHeader == nil { + pkt.NetworkHeader = pkt.Data.First() + } + + // Check whether the packet matches the IP header filter. + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + + // Go through each rule matcher. If they all match, run + // the rule target. + for _, matcher := range rule.Matchers { + matches, hotdrop := matcher.Match(hook, pkt, "") + if hotdrop { + return RuleDrop, 0 + } + if !matches { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + } + + // All the matchers matched, so run the target. + return rule.Target.Action(pkt) +} + +func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the destination IP. + dest := hdr.DestinationAddress() + matches := true + for i := range filter.Dst { + if dest[i]&filter.DstMask[i] != filter.Dst[i] { + matches = false + break + } + } + if matches == filter.DstInvert { + return false + } + + return true +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go new file mode 100644 index 000000000..7b4543caf --- /dev/null +++ b/pkg/tcpip/stack/iptables_targets.go @@ -0,0 +1,144 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// AcceptTarget accepts packets. +type AcceptTarget struct{} + +// Action implements Target.Action. +func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleAccept, 0 +} + +// DropTarget drops packets. +type DropTarget struct{} + +// Action implements Target.Action. +func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleDrop, 0 +} + +// ErrorTarget logs an error and drops the packet. It represents a target that +// should be unreachable. +type ErrorTarget struct{} + +// Action implements Target.Action. +func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + log.Debugf("ErrorTarget triggered.") + return RuleDrop, 0 +} + +// UserChainTarget marks a rule as the beginning of a user chain. +type UserChainTarget struct { + Name string +} + +// Action implements Target.Action. +func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { + panic("UserChainTarget should never be called.") +} + +// ReturnTarget returns from the current chain. If the chain is a built-in, the +// hook's underflow should be called. +type ReturnTarget struct{} + +// Action implements Target.Action. +func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { + return RuleReturn, 0 +} + +// RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. +type RedirectTarget struct { + // TODO(gvisor.dev/issue/170): Other flags need to be added after + // we support them. + // RangeProtoSpecified flag indicates single port is specified to + // redirect. + RangeProtoSpecified bool + + // Min address used to redirect. + MinIP tcpip.Address + + // Max address used to redirect. + MaxIP tcpip.Address + + // Min port used to redirect. + MinPort uint16 + + // Max port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +// TODO(gvisor.dev/issue/170): Parse headers without copying. The current +// implementation only works for PREROUTING and calls pkt.Clone(), neither +// of which should be the case. +func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { + newPkt := pkt.Clone() + + // Set network header. + headerView := newPkt.Data.First() + netHeader := header.IPv4(headerView) + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + + hlen := int(netHeader.HeaderLength()) + tlen := int(netHeader.TotalLength()) + newPkt.Data.TrimFront(hlen) + newPkt.Data.CapLength(tlen - hlen) + + // TODO(gvisor.dev/issue/170): Change destination address to + // loopback or interface address on which the packet was + // received. + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + var udpHeader header.UDP + if newPkt.TransportHeader != nil { + udpHeader = header.UDP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.UDPMinimumSize { + return RuleDrop, 0 + } + udpHeader = header.UDP(newPkt.Data.First()) + } + udpHeader.SetDestinationPort(rt.MinPort) + case header.TCPProtocolNumber: + var tcpHeader header.TCP + if newPkt.TransportHeader != nil { + tcpHeader = header.TCP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.TCPMinimumSize { + return RuleDrop, 0 + } + tcpHeader = header.TCP(newPkt.TransportHeader) + } + // TODO(gvisor.dev/issue/170): Need to recompute checksum + // and implement nat connection tracking to support TCP. + tcpHeader.SetDestinationPort(rt.MinPort) + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go new file mode 100644 index 000000000..2ffb55f2a --- /dev/null +++ b/pkg/tcpip/stack/iptables_types.go @@ -0,0 +1,180 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +// A Hook specifies one of the hooks built into the network stack. +// +// Userspace app Userspace app +// ^ | +// | v +// [Input] [Output] +// ^ | +// | v +// | routing +// | | +// | v +// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> +type Hook uint + +// These values correspond to values in include/uapi/linux/netfilter.h. +const ( + // Prerouting happens before a packet is routed to applications or to + // be forwarded. + Prerouting Hook = iota + + // Input happens before a packet reaches an application. + Input + + // Forward happens once it's decided that a packet should be forwarded + // to another host. + Forward + + // Output happens after a packet is written by an application to be + // sent out. + Output + + // Postrouting happens just before a packet goes out on the wire. + Postrouting + + // The total number of hooks. + NumHooks +) + +// A RuleVerdict is what a rule decides should be done with a packet. +type RuleVerdict int + +const ( + // RuleAccept indicates the packet should continue through netstack. + RuleAccept RuleVerdict = iota + + // RuleDrop indicates the packet should be dropped. + RuleDrop + + // RuleJump indicates the packet should jump to another chain. + RuleJump + + // RuleReturn indicates the packet should return to the previous chain. + RuleReturn +) + +// IPTables holds all the tables for a netstack. +type IPTables struct { + // Tables maps table names to tables. User tables have arbitrary names. + Tables map[string]Table + + // Priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. + Priorities map[Hook][]string +} + +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules with some metadata for entrypoints and such. +type Table struct { + // Rules holds the rules that make up the table. + Rules []Rule + + // BuiltinChains maps builtin chains to their entrypoint rule in Rules. + BuiltinChains map[Hook]int + + // Underflows maps builtin chains to their underflow rule in Rules + // (i.e. the rule to execute if the chain returns without a verdict). + Underflows map[Hook]int + + // UserChains holds user-defined chains for the keyed by name. Users + // can give their chains arbitrary names. + UserChains map[string]int + + // Metadata holds information about the Table that is useful to users + // of IPTables, but not to the netstack IPTables code itself. + metadata interface{} +} + +// ValidHooks returns a bitmap of the builtin hooks for the given table. +func (table *Table) ValidHooks() uint32 { + hooks := uint32(0) + for hook := range table.BuiltinChains { + hooks |= 1 << hook + } + return hooks +} + +// Metadata returns the metadata object stored in table. +func (table *Table) Metadata() interface{} { + return table.metadata +} + +// SetMetadata sets the metadata object stored in table. +func (table *Table) SetMetadata(metadata interface{}) { + table.metadata = metadata +} + +// A Rule is a packet processing rule. It consists of two pieces. First it +// contains zero or more matchers, each of which is a specification of which +// packets this rule applies to. If there are no matchers in the rule, it +// applies to any packet. +type Rule struct { + // Filter holds basic IP filtering fields common to every rule. + Filter IPHeaderFilter + + // Matchers is the list of matchers for this rule. + Matchers []Matcher + + // Target is the action to invoke if all the matchers match the packet. + Target Target +} + +// IPHeaderFilter holds basic IP filtering data common to every rule. +type IPHeaderFilter struct { + // Protocol matches the transport protocol. + Protocol tcpip.TransportProtocolNumber + + // Dst matches the destination IP address. + Dst tcpip.Address + + // DstMask masks bits of the destination IP address when comparing with + // Dst. + DstMask tcpip.Address + + // DstInvert inverts the meaning of the destination IP check, i.e. when + // true the filter will match packets that fail the destination + // comparison. + DstInvert bool +} + +// A Matcher is the interface for matching packets. +type Matcher interface { + // Name returns the name of the Matcher. + Name() string + + // Match returns whether the packet matches and whether the packet + // should be "hotdropped", i.e. dropped immediately. This is usually + // used for suspicious packets. + // + // Precondition: packet.NetworkHeader is set. + Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) +} + +// A Target is the interface for taking an action for a packet. +type Target interface { + // Action takes an action on the packet and returns a verdict on how + // traversal should (or should not) continue. If the return value is + // Jump, it also returns the index of the rule to jump to. + Action(packet PacketBuffer) (RuleVerdict, int) +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d689a006d..630fdefc5 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -564,7 +564,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1283,7 +1283,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4368c236c..06edd05b6 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -602,7 +602,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -918,7 +918,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -953,14 +953,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -969,7 +969,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -977,7 +977,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -986,7 +986,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9dcb1d52c..b6fa647ea 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" ) var ipv4BroadcastAddr = tcpip.ProtocolAddress{ @@ -1144,7 +1143,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1158,7 +1157,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1222,7 +1221,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { ipt := n.stack.IPTables() - if ok := ipt.Check(iptables.Prerouting, pkt); !ok { + if ok := ipt.Check(Prerouting, pkt); !ok { // iptables is telling us to drop the packet. return } @@ -1287,7 +1286,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. firstData := pkt.Data.First() @@ -1318,7 +1317,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1364,7 +1363,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index edaee3b86..d672fc157 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -17,7 +17,6 @@ package stack import ( "testing" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go new file mode 100644 index 000000000..1850fa8c3 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer.go @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package stack + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go new file mode 100644 index 000000000..76602549e --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_state.go @@ -0,0 +1,26 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package stack + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index fa28b46b1..ac043b722 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,15 +242,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +265,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -322,7 +322,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -354,7 +354,7 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets // out through the implementer's data link endpoint. When a link header exists, -// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// it sets each PacketBuffer's LinkHeader field before passing it up the // stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is @@ -385,7 +385,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. @@ -426,7 +426,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f565aafb2..9fbe8a411 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -169,7 +169,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack } // WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } @@ -190,7 +190,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6f423874a..a9584d636 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -31,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -51,7 +50,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -428,7 +427,7 @@ type Stack struct { // tables are the iptables packet filtering and manipulation rules. The are // protected by tablesMu.` - tables iptables.IPTables + tables IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -738,7 +737,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1701,7 +1700,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() iptables.IPTables { +func (s *Stack) IPTables() IPTables { s.tablesMu.RLock() t := s.tables s.tablesMu.RUnlock() @@ -1709,7 +1708,7 @@ func (s *Stack) IPTables() iptables.IPTables { } // SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt iptables.IPTables) { +func (s *Stack) SetIPTables(ipt IPTables) { s.tablesMu.Lock() s.tables = ipt s.tablesMu.Unlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9836b340f..555fcd92f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -126,7 +126,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, tcpip.PacketBuffer{ + f.HandlePacket(r, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -153,11 +153,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -287,7 +287,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -299,7 +299,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -311,7 +311,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -322,7 +322,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -334,7 +334,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -356,7 +356,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -414,7 +414,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -2257,7 +2257,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2339,7 +2339,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d4c0359e8..c55e3e8bc 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -85,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -116,7 +116,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -184,7 +184,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -312,7 +312,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -403,7 +403,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -453,7 +453,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -477,7 +477,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 0e3e239c5..84311bcc8 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -150,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5d1da2f8b..8ca9ac3cf 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -87,7 +86,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -214,7 +213,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -231,7 +230,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -242,8 +241,8 @@ func (f *fakeTransportEndpoint) State() uint32 { func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil +func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { + return stack.IPTables{}, nil } func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} @@ -288,7 +287,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { return true } @@ -368,7 +367,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -379,7 +378,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -390,7 +389,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -445,7 +444,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -456,7 +455,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -467,7 +466,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -622,7 +621,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: req.ToVectorisedView(), }) diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index ac18ec5b1..9ce625c17 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 2a396e9bc..613b12ead 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -135,7 +134,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -441,7 +440,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -471,7 +470,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -733,7 +732,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -795,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 113d92901..3c47692b2 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index d22de6b26..b989b1209 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 09a1cd436..df49d0995 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -100,8 +99,8 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb } // Abort implements stack.TransportEndpoint.Abort. -func (e *endpoint) Abort() { - e.Close() +func (ep *endpoint) Abort() { + ep.Close() } // Close implements tcpip.Endpoint.Close. @@ -134,7 +133,7 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { +func (ep *endpoint) IPTables() (stack.IPTables, error) { return ep.stack.IPTables(), nil } @@ -299,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index c9baf4600..2eab09088 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/packet", "//pkg/waiter", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 2ef5fac76..536dafd1e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -161,7 +160,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -342,7 +341,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -350,7 +349,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -574,7 +573,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 2fdf6c0a5..7f94f9646 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -66,7 +66,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 53193afc6..79552fc61 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -705,7 +705,7 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { optLen := len(opts) hdr := &pkt.Header packetSize := pkt.DataSize @@ -752,7 +752,7 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect // Allocate one big slice for all the headers. hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen buf := make([]byte, n*hdrSize) - pkts := make([]tcpip.PacketBuffer, n) + pkts := make([]stack.PacketBuffer, n) for i := range pkts { pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) } @@ -795,7 +795,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) } - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), DataOffset: 0, DataSize: data.Size(), diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 90ac956a9..6062ca916 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -187,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index eb8a9d73e..594efaa11 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -1120,7 +1119,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { } // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -2388,7 +2387,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2407,7 +2406,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index c9ee5bf06..a094471b8 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b0f918bb4..57985b85d 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -140,7 +140,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -151,7 +151,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 5d0bc4f72..e6fe7985d 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,7 +18,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -61,7 +60,7 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 8cea20fb5..d4f6bc635 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -307,7 +307,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -363,7 +363,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: s, }) } @@ -371,7 +371,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -380,7 +380,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -548,7 +548,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index adc908e24..b5d2d0ba6 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0af4514e1..a3372ac58 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -234,7 +233,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -913,7 +912,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1260,7 +1259,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.Data.First()) if int(hdr.Length()) > pkt.Data.Size() { @@ -1327,7 +1326,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index fc706ede2..a674ceb68 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt tcpip.PacketBuffer + pkt stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 8df089d22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. hdr := header.UDP(pkt.Data.First()) if int(hdr.Length()) > pkt.Data.Size() { @@ -135,7 +135,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -172,7 +172,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 34b7c2360..0905726c1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -439,7 +439,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -486,7 +486,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), -- cgit v1.2.3 From d8c4eff3f77b2a5dde34389033fe0ce4589b2e82 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 25 Mar 2020 08:10:11 -0700 Subject: Automated rollback of changelist 301837227 PiperOrigin-RevId: 302891559 --- pkg/sentry/socket/netstack/netstack.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c19f5639b..f14c336b9 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,6 +29,7 @@ import ( "io" "math" "reflect" + "sync/atomic" "syscall" "time" @@ -264,8 +265,14 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. - readMu sync.RWMutex `state:"nosave"` + readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. readView buffer.View // readCM holds control message information for the last packet read @@ -421,11 +428,13 @@ func (s *SocketOperations) fetchReadView() *syserr.Error { v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -624,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.RLock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.RUnlock() } return r @@ -2335,6 +2342,9 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } dst = dst.DropFirst(n) if e != nil { @@ -2458,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC -- cgit v1.2.3 From 92b9069b67b927cef25a1490ebd142ad6d65690d Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Fri, 20 Mar 2020 12:00:21 -0700 Subject: Support owner matching for iptables. This feature will match UID and GID of the packet creator, for locally generated packets. This match is only valid in the OUTPUT and POSTROUTING chains. Forwarded packets do not have any socket associated with them. Packets from kernel threads do have a socket, but usually no owner. --- pkg/abi/linux/netfilter.go | 41 ++++++++ pkg/abi/linux/netfilter_test.go | 1 + pkg/sentry/kernel/task.go | 12 +++ pkg/sentry/socket/netfilter/BUILD | 1 + pkg/sentry/socket/netfilter/netfilter.go | 7 +- pkg/sentry/socket/netfilter/owner_matcher.go | 128 ++++++++++++++++++++++++ pkg/sentry/socket/netstack/provider.go | 6 ++ pkg/tcpip/network/ipv4/ipv4.go | 15 +++ pkg/tcpip/stack/packet_buffer.go | 9 +- pkg/tcpip/stack/transport_test.go | 2 + pkg/tcpip/tcpip.go | 12 +++ pkg/tcpip/transport/icmp/endpoint.go | 12 ++- pkg/tcpip/transport/packet/endpoint.go | 2 + pkg/tcpip/transport/raw/endpoint.go | 9 ++ pkg/tcpip/transport/tcp/accept.go | 5 +- pkg/tcpip/transport/tcp/connect.go | 10 +- pkg/tcpip/transport/tcp/endpoint.go | 7 ++ pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 12 ++- test/iptables/filter_output.go | 143 +++++++++++++++++++++++++++ test/iptables/iptables_test.go | 30 ++++++ 22 files changed, 451 insertions(+), 17 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/owner_matcher.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 80dc09aa9..a8d4f9d69 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -509,3 +509,44 @@ const ( // Enable all flags. XT_UDP_INV_MASK = 0x03 ) + +// IPTOwnerInfo holds data for matching packets with owner. It corresponds +// to struct ipt_owner_info in libxt_owner.c of iptables binary. +type IPTOwnerInfo struct { + // UID is user id which created the packet. + UID uint32 + + // GID is group id which created the packet. + GID uint32 + + // PID is process id of the process which created the packet. + PID uint32 + + // SID is session id which created the packet. + SID uint32 + + // Comm is the command name which created the packet. + Comm [16]byte + + // Match is used to match UID/GID of the socket. See the + // XT_OWNER_* flags below. + Match uint8 + + // Invert flips the meaning of Match field. + Invert uint8 +} + +// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. +const SizeOfIPTOwnerInfo = 34 + +// Flags in IPTOwnerInfo.Match. Corresponding constants are in +// include/uapi/linux/netfilter/xt_owner.h. +const ( + // Match the UID of the packet. + XT_OWNER_UID = 1 << 0 + // Match the GID of the packet. + XT_OWNER_GID = 1 << 1 + // Match if the socket exists for the packet. Forwarded + // packets do not have an associated socket. + XT_OWNER_SOCKET = 1 << 2 +) diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index 21e237f92..565dd550e 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -29,6 +29,7 @@ func TestSizes(t *testing.T) { {IPTGetEntries{}, SizeOfIPTGetEntries}, {IPTGetinfo{}, SizeOfIPTGetinfo}, {IPTIP{}, SizeOfIPTIP}, + {IPTOwnerInfo{}, SizeOfIPTOwnerInfo}, {IPTReplace{}, SizeOfIPTReplace}, {XTCounters{}, SizeOfXTCounters}, {XTEntryMatch{}, SizeOfXTEntryMatch}, diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 8452ddf5b..d6546735e 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -863,3 +863,15 @@ func (t *Task) SetOOMScoreAdj(adj int32) error { atomic.StoreInt32(&t.tg.oomScoreAdj, adj) return nil } + +// UID returns t's uid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) UID() uint32 { + return uint32(t.Credentials().EffectiveKUID) +} + +// GID returns t's gid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) GID() uint32 { + return uint32(t.Credentials().EffectiveKGID) +} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index e801abeb8..721094bbf 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "netfilter.go", + "owner_matcher.go", "targets.go", "tcp_matcher.go", "udp_matcher.go", diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 55bcc3ace..878f81fd5 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -517,11 +517,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT chain and redirect for - // PREROUTING chain right now, make sure all other chains point to - // ACCEPT rules. + // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, + // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != stack.Input && hook != stack.Prerouting { + if hook == stack.Forward || hook == stack.Postrouting { if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go new file mode 100644 index 000000000..5949a7c29 --- /dev/null +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameOwner = "owner" + +func init() { + registerMatchMaker(ownerMarshaler{}) +} + +// ownerMarshaler implements matchMaker for owner matching. +type ownerMarshaler struct{} + +// name implements matchMaker.name. +func (ownerMarshaler) name() string { + return matcherNameOwner +} + +// marshal implements matchMaker.marshal. +func (ownerMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*OwnerMatcher) + iptOwnerInfo := linux.IPTOwnerInfo{ + UID: matcher.uid, + GID: matcher.gid, + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matcher.matchUID { + iptOwnerInfo.Match = linux.XT_OWNER_UID + } else if matcher.matchGID { + panic("GID match is not supported.") + } else { + panic("UID match is not set.") + } + + buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) + return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo)) +} + +// unmarshal implements matchMaker.unmarshal. +func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfIPTOwnerInfo { + return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.IPTOwnerInfo + binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + + if matchData.Invert != 0 { + return nil, fmt.Errorf("invert flag is not supported for owner match") + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID { + return nil, fmt.Errorf("owner match is only supported for uid") + } + + // Check Flags. + var owner OwnerMatcher + owner.uid = matchData.UID + owner.gid = matchData.GID + owner.matchUID = true + + return &owner, nil +} + +type OwnerMatcher struct { + uid uint32 + gid uint32 + matchUID bool + matchGID bool + invert uint8 +} + +// Name implements Matcher.Name. +func (*OwnerMatcher) Name() string { + return matcherNameOwner +} + +// Match implements Matcher.Match. +func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { + // Support only for OUTPUT chain. + // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. + if hook != stack.Output { + return false, true + } + + // If the packet owner is not set, drop the packet. + // Support for uid match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if pkt.Owner == nil || !om.matchUID { + return false, true + } + + // TODO(gvisor.dev/issue/170): Need to add tests to verify + // drop rule when packet UID does not match owner matcher UID. + if pkt.Owner.UID() != om.uid { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5f181f017..eb090e79b 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -126,6 +126,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } } if e != nil { return nil, syserr.TranslateNetstackError(e) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b3ee6000e..a7d9a8b25 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -244,6 +244,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt); !ok { + // iptables is telling us to drop the packet. + return nil + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -280,7 +288,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac return len(pkts), nil } + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() for i := range pkts { + if ok := ipt.Check(stack.Output, pkts[i]); !ok { + // iptables is telling us to drop the packet. + continue + } ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) pkts[i].NetworkHeader = buffer.View(ip) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9505a4e92..9367de180 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -13,7 +13,10 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) // A PacketBuffer contains all the data of a network packet. // @@ -59,6 +62,10 @@ type PacketBuffer struct { // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 + + // Owner is implemented by task to get the uid and gid. + // Only set for locally generated packets. + Owner tcpip.PacketOwner } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8ca9ac3cf..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -56,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } +func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} + func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3dc5d87d6..2ef3271f1 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -336,6 +336,15 @@ type ControlMessages struct { PacketInfo IPPacketInfo } +// PacketOwner is used to get UID and GID of the packet. +type PacketOwner interface { + // UID returns UID of the packet. + UID() uint32 + + // GID returns GID of the packet. + GID() uint32 +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -470,6 +479,9 @@ type Endpoint interface { // Stats returns a reference to the endpoint stats. Stats() EndpointStats + + // SetOwner sets the task owner to the endpoint owner. + SetOwner(owner PacketOwner) } // EndpointInfo is the interface implemented by each endpoint info struct. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 613b12ead..b007302fb 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -73,6 +73,9 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -133,6 +136,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -321,7 +328,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl) + err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) case header.IPv6ProtocolNumber: err = send6(route, e.ID.LocalPort, v, e.ttl) @@ -415,7 +422,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -444,6 +451,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), + Owner: owner, }) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df49d0995..23158173d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -392,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo { func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } + +func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 536dafd1e..337bc1c71 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -80,6 +80,9 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // NewEndpoint returns a raw endpoint for the given protocols. @@ -159,6 +162,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -348,10 +355,12 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } break } + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), + Owner: e.owner, }); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 375ca21f6..7a9dea4ac 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -276,7 +276,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // and then performs the TCP 3-way handshake. // // The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -284,6 +284,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err != nil { return nil, err } + ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) @@ -414,7 +415,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header }() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 1d245c2c6..3239a5911 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -745,7 +745,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { tf.txHash = e.txHash - if err := sendTCP(r, tf, data, gso); err != nil { + if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -787,7 +787,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -816,6 +816,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkts[i].DataSize = packetSize pkts[i].Data = data pkts[i].Hash = tf.txHash + pkts[i].Owner = owner buildTCPHdr(r, tf, &pkts[i], gso) off += packetSize tf.seq = tf.seq.Add(seqnum.Size(packetSize)) @@ -833,14 +834,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, tf, data, gso) + return sendTCPBatch(r, tf, data, gso, owner) } pkt := stack.PacketBuffer{ @@ -849,6 +850,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac DataSize: data.Size(), Data: data, Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ebee0cfe..9b123e968 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -603,6 +603,9 @@ type endpoint struct { // txHash is the transport layer hash to be set on outbound packets // emitted by this endpoint. txHash uint32 + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1132,6 +1135,10 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index a094471b8..808410c92 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,7 +157,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }, queue) + }, queue, nil) if err != nil { return nil, err } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1377107ca..dce9a1652 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -199,7 +199,7 @@ func replyWithReset(s *segment) { seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */) + }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a3372ac58..120d3baa3 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -143,6 +143,9 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // +stateify savable @@ -484,7 +487,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -886,7 +889,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -916,6 +919,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Header: hdr, Data: data, TransportHeader: buffer.View(udp), + Owner: owner, }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err @@ -1356,3 +1360,7 @@ func (*endpoint) Wait() {} func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index 4582d514c..f6d974b85 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -24,6 +24,11 @@ func init() { RegisterTestCase(FilterOutputDropTCPSrcPort{}) RegisterTestCase(FilterOutputDestination{}) RegisterTestCase(FilterOutputInvertDestination{}) + RegisterTestCase(FilterOutputAcceptTCPOwner{}) + RegisterTestCase(FilterOutputDropTCPOwner{}) + RegisterTestCase(FilterOutputAcceptUDPOwner{}) + RegisterTestCase(FilterOutputDropUDPOwner{}) + RegisterTestCase(FilterOutputOwnerFail{}) } // FilterOutputDropTCPDestPort tests that connections are not accepted on @@ -90,6 +95,144 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { return nil } +// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. +type FilterOutputAcceptTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptTCPOwner) Name() string { + return "FilterOutputAcceptTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. +type FilterOutputDropTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropTCPOwner) Name() string { + return "FilterOutputDropTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. +type FilterOutputAcceptUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptUDPOwner) Name() string { + return "FilterOutputAcceptUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Send UDP packets on acceptPort. + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. +type FilterOutputDropUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropUDPOwner) Name() string { + return "FilterOutputDropUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Send UDP packets on dropPort. + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received") + } + + return nil +} + +// FilterOutputOwnerFail tests that without uid/gid option, owner rule +// will fail. +type FilterOutputOwnerFail struct{} + +// Name implements TestCase.Name. +func (FilterOutputOwnerFail) Name() string { + return "FilterOutputOwnerFail" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { + return fmt.Errorf("Invalid argument") + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { + // no-op. + return nil +} + // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. type FilterOutputDestination struct{} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 7f1f70606..493d69052 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -274,6 +274,36 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) { } } +func TestFilterOutputAcceptTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputAcceptUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputOwnerFail(t *testing.T) { + if err := singleTest(FilterOutputOwnerFail{}); err != nil { + t.Fatal(err) + } +} + func TestJumpSerialize(t *testing.T) { if err := singleTest(FilterInputSerializeJump{}); err != nil { t.Fatal(err) -- cgit v1.2.3 From 5b2396d244ed6283d928a72bdd4cc58d78ef3175 Mon Sep 17 00:00:00 2001 From: Dean Deng Date: Thu, 2 Apr 2020 17:06:19 -0700 Subject: Fix typo in TODO comments. PiperOrigin-RevId: 304508083 --- pkg/sentry/fs/proc/mounts.go | 3 ++- pkg/sentry/fsimpl/tmpfs/filesystem.go | 6 +++--- pkg/sentry/fsimpl/tmpfs/stat_test.go | 4 ++-- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/vfs/mount.go | 3 ++- test/syscalls/linux/socket_inet_loopback.cc | 2 +- 7 files changed, 12 insertions(+), 10 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index 94deb553b..1fc9c703c 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -170,7 +170,8 @@ func superBlockOpts(mountPath string, msrc *fs.MountSource) string { // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. - // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we + // + // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we // should get this value from the cgroup itself, and not rely on the // path. if msrc.FilesystemType == "cgroup" { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e678ecc37..4cf27bf13 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -57,7 +57,7 @@ afterSymlink: } next := nextVFSD.Impl().(*dentry) if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { - // TODO(gvisor.dev/issues/1197): Symlink traversals updates + // TODO(gvisor.dev/issue/1197): Symlink traversals updates // access time. if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err @@ -515,7 +515,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa oldParent.inode.decLinksLocked() newParent.inode.incLinksLocked() } - // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory + // TODO(gvisor.dev/issue/1197): Update timestamps and parent directory // sizes. vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD) return nil @@ -600,7 +600,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu if err != nil { return linux.Statfs{}, err } - // TODO(gvisor.dev/issues/1197): Actually implement statfs. + // TODO(gvisor.dev/issue/1197): Actually implement statfs. return linux.Statfs{}, syserror.ENOSYS } diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go index ebe035dee..3e02e7190 100644 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go @@ -29,7 +29,7 @@ func TestStatAfterCreate(t *testing.T) { mode := linux.FileMode(0644) // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. + // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets. for _, typ := range []string{"file", "dir", "pipe"} { t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { var ( @@ -169,7 +169,7 @@ func TestSetStat(t *testing.T) { mode := linux.FileMode(0644) // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. + // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets. for _, typ := range []string{"file", "dir", "pipe"} { t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { var ( diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 8bc8818c0..54da15849 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -315,7 +315,7 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Atime = linux.NsecToStatxTimestamp(i.atime) stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) - // TODO(gvisor.dev/issues/1197): Device number. + // TODO(gvisor.dev/issue/1197): Device number. switch impl := i.impl.(type) { case *regularFile: stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f14c336b9..06a5b53bc 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -663,7 +663,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error // This is a hack to work around the fact that both IPv4 and IPv6 ANY are // represented by the empty string. // -// TODO(gvisor.dev/issues/1556): remove this function. +// TODO(gvisor.dev/issue/1556): remove this function. func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 7792eb1a0..1b8ecc415 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -835,7 +835,8 @@ func superBlockOpts(mountPath string, mnt *Mount) string { // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. - // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we + // + // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we // should get this value from the cgroup itself, and not rely on the // path. if mnt.fs.FilesystemType().Name() == "cgroup" { diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 16888de2a..2ffc86382 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -234,7 +234,7 @@ TEST_P(DualStackSocketTest, AddressOperations) { } } -// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. +// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny. INSTANTIATE_TEST_SUITE_P( All, DualStackSocketTest, ::testing::Combine( -- cgit v1.2.3 From 5818663ebe26857845685702d99db41c7aa2cf3d Mon Sep 17 00:00:00 2001 From: Dean Deng Date: Fri, 3 Apr 2020 14:07:42 -0700 Subject: Add FileDescriptionImpl for Unix sockets. This change involves several steps: - Refactor the VFS1 unix socket implementation to share methods between VFS1 and VFS2 where possible. Re-implement the rest. - Override the default PRead, Read, PWrite, Write, Ioctl, Release methods in FileDescriptionDefaultImpl. - Add functions to create and initialize a new Dentry/Inode and FileDescription for a Unix socket file. Updates #1476 PiperOrigin-RevId: 304689796 --- pkg/sentry/fsimpl/sockfs/BUILD | 1 + pkg/sentry/fsimpl/sockfs/sockfs.go | 29 +++ pkg/sentry/kernel/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 8 +- pkg/sentry/socket/unix/BUILD | 4 + pkg/sentry/socket/unix/unix.go | 89 ++++++--- pkg/sentry/socket/unix/unix_vfs2.go | 348 +++++++++++++++++++++++++++++++++ pkg/sentry/vfs/options.go | 5 + 8 files changed, 456 insertions(+), 29 deletions(-) create mode 100644 pkg/sentry/socket/unix/unix_vfs2.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD index 790d50e65..52084ddb5 100644 --- a/pkg/sentry/fsimpl/sockfs/BUILD +++ b/pkg/sentry/fsimpl/sockfs/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["sockfs.go"], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/context", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index c13511de2..3f7ad1d65 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -16,6 +16,7 @@ package sockfs import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -60,6 +61,10 @@ type filesystem struct { } // inode implements kernfs.Inode. +// +// TODO(gvisor.dev/issue/1476): Add device numbers to this inode (which are +// not included in InodeAttrs) to store the numbers of the appropriate +// socket device. Override InodeAttrs.Stat() accordingly. type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -71,3 +76,27 @@ type inode struct { func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ENXIO } + +// InitSocket initializes a socket FileDescription, with a corresponding +// Dentry in mnt. +// +// fd should be the FileDescription associated with socketImpl, i.e. its first +// field. mnt should be the global socket mount, Kernel.socketMount. +func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error { + fsimpl := mnt.Filesystem().Impl() + fs := fsimpl.(*kernfs.Filesystem) + + // File mode matches net/socket.c:sock_alloc. + filemode := linux.FileMode(linux.S_IFSOCK | 0600) + i := &inode{} + i.Init(creds, fs.NextIno(), filemode) + + d := &kernfs.Dentry{} + d.Init(i) + + opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true} + if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil { + return err + } + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index bb7e3cbc3..e0ff58d8c 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -169,6 +169,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostcpu", "//pkg/sentry/inet", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 06a5b53bc..5d0085462 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -940,7 +940,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -966,7 +966,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1541,7 +1541,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // SetSockOpt can be used to implement the linux syscall setsockopt(2) for // sockets backed by a commonEndpoint. -func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { +func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: return setSockOptSocket(t, s, ep, name, optVal) @@ -1568,7 +1568,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. -func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.SO_SNDBUF: if len(optVal) < sizeOfInt32 { diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 08743deba..de2cc4bdf 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,23 +8,27 @@ go_library( "device.go", "io.go", "unix.go", + "unix_vfs2.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fspath", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4d30aa714..7c64f30fa 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -33,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -52,11 +54,8 @@ type SocketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - refs.AtomicRefCount - socket.SendReceiveTimeout - ep transport.Endpoint - stype linux.SockType + socketOpsCommon } // New creates a new unix socket. @@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty } s := SocketOperations{ - ep: ep, - stype: stype, + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, } s.EnableLeakCheck("unix.SocketOperations") return fs.NewFile(ctx, d, flags, &s) } +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + refs.AtomicRefCount + socket.SendReceiveTimeout + + ep transport.Endpoint + stype linux.SockType +} + // DecRef implements RefCounter.DecRef. -func (s *SocketOperations) DecRef() { +func (s *socketOpsCommon) DecRef() { s.DecRefWithDestructor(func() { s.ep.Close() }) @@ -97,7 +109,7 @@ func (s *SocketOperations) Release() { s.DecRef() } -func (s *SocketOperations) isPacket() bool { +func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: return true @@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool { } // Endpoint extracts the transport.Endpoint. -func (s *SocketOperations) Endpoint() transport.Endpoint { +func (s *socketOpsCommon) Endpoint() transport.Endpoint { return s.ep } @@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return s.ep.Listen(backlog) } @@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Create the socket. + // + // TODO(gvisor.dev/issue/2324): Correctly set file permissions. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse @@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, return ep, nil } + if kernel.VFS2Enabled { + p := fspath.Parse(path) + root := t.FSContext().RootDirectoryVFS2() + start := root + relPath := !p.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: p, + FollowFinalSymlink: true, + } + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop) + root.DecRef() + if relPath { + start.DecRef() + } + if e != nil { + return nil, syserr.FromError(e) + } + return ep, nil + } + // Find the node in the filesystem. root := t.FSContext().RootDirectory() cwd := t.FSContext().WorkingDirectory() @@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, // No socket! return nil, syserr.ErrConnectionRefused } - return ep, nil } // Connect implements the linux syscall connect(2) for unix sockets. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { ep, err := extractEndpoint(t, sockaddr) if err != nil { return err @@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo return s.ep.Connect(t, ep) } -// Writev implements fs.FileOperations.Write. +// Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { t := kernel.TaskFromContext(ctx) ctrl := control.New(t, s.ep, nil) @@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Ctx: t, Endpoint: s.ep, @@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Passcred implements transport.Credentialer.Passcred. -func (s *SocketOperations) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { return s.ep.Passcred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *SocketOperations) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } // Readiness implements waiter.Waitable.Readiness. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.ep.Readiness(mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *SocketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } @@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := netstack.ConvertShutdown(how) if err != nil { return err @@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // State implements socket.Socket.State. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { // Unix domain sockets always have a protocol of 0. return linux.AF_UNIX, s.stype, 0 } @@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F func init() { socket.RegisterProvider(linux.AF_UNIX, &provider{}) + socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{}) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go new file mode 100644 index 000000000..ca1388e2c --- /dev/null +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -0,0 +1,348 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SocketVFS2 implements socket.SocketVFS2 (and by extension, +// vfs.FileDescriptionImpl) for Unix sockets. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket. +func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { + sock := NewFDImpl(ep, stype) + vfsfd := &sock.vfsfd + if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// NewFDImpl creates and returns a new SocketVFS2. +func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 { + // You can create AF_UNIX, SOCK_RAW sockets. They're the same as + // SOCK_DGRAM and don't require CAP_NET_RAW. + if stype == linux.SOCK_RAW { + stype = linux.SOCK_DGRAM + } + + return &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, + } +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +} + +// blockingAccept implements a blocking version of accept(2), that is, if no +// connections are ready to be accept, it will block until one becomes ready. +func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { + // Register for notifications. + e, ch := waiter.NewChannelEntry(nil) + s.socketOpsCommon.EventRegister(&e, waiter.EventIn) + defer s.socketOpsCommon.EventUnregister(&e) + + // Try to accept the connection; if it fails, then wait until we get a + // notification. + for { + if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + return ep, err + } + + if err := t.Block(ch); err != nil { + return nil, syserr.FromError(err) + } + } +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, err := s.ep.Accept() + if err != nil { + if err != syserr.ErrWouldBlock || !blocking { + return 0, nil, 0, err + } + + var err *syserr.Error + ep, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + // We expect this to be a FileDescription here. + ns, err := NewVFS2File(t, ep, s.stype) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef() + + if flags&linux.SOCK_NONBLOCK != 0 { + ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + if e != nil { + return 0, nil, 0, syserr.FromError(e) + } + + // TODO: add vfs2 sockets to global table. + return fd, addr, addrLen, nil +} + +// Bind implements the linux syscall bind(2) for unix sockets. +func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + p, e := extractPath(sockaddr) + if e != nil { + return e + } + + bep, ok := s.ep.(transport.BoundEndpoint) + if !ok { + // This socket can't be bound. + return syserr.ErrInvalidArgument + } + + return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error { + // Is it abstract? + if p[0] == 0 { + if t.IsNetworkNamespaced() { + return syserr.ErrInvalidEndpointState + } + if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + // syserr.ErrPortInUse corresponds to EADDRINUSE. + return syserr.ErrPortInUse + } + } else { + path := fspath.Parse(p) + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + relPath := !path.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + } + err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ + // TODO(gvisor.dev/issue/2324): The file permissions should be taken + // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind), + // but VFS1 just always uses 0400. Resolve this inconsistency. + Mode: linux.S_IFSOCK | 0400, + Endpoint: bep, + }) + if err == syserror.EEXIST { + return syserr.ErrAddressInUse + } + return syserr.FromError(err) + } + + return nil + }) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return netstack.Ioctl(ctx, s.ep, uio, args) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &EndpointReader{ + Ctx: ctx, + Endpoint: s.ep, + NumRights: 0, + Peek: false, + From: nil, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + t := kernel.TaskFromContext(ctx) + ctrl := control.New(t, s.ep, nil) + + if src.NumBytes() == 0 { + nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil) + return int64(nInt), err.ToError() + } + + return src.CopyInTo(ctx, &EndpointWriter{ + Ctx: ctx, + Endpoint: s.ep, + Control: ctrl, + To: nil, + }) +} + +// Release implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Release() { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef() +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) +} + +// providerVFS2 is a unix domain socket provider for VFS2. +type providerVFS2 struct{} + +func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, syserr.ErrProtocolNotSupported + } + + // Create the endpoint and socket. + var ep transport.Endpoint + switch stype { + case linux.SOCK_DGRAM, linux.SOCK_RAW: + ep = transport.NewConnectionless(t) + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: + ep = transport.NewConnectioned(t, stype, t.Kernel()) + default: + return nil, syserr.ErrInvalidArgument + } + + f, err := NewVFS2File(t, ep, stype) + if err != nil { + ep.Close() + return nil, err + } + return f, nil +} + +// Pair creates a new pair of AF_UNIX connected sockets. +func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, nil, syserr.ErrProtocolNotSupported + } + + switch stype { + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW: + // Ok + default: + return nil, nil, syserr.ErrInvalidArgument + } + + // Create the endpoints and sockets. + ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) + s1, err := NewVFS2File(t, ep1, stype) + if err != nil { + ep1.Close() + ep2.Close() + return nil, nil, err + } + s2, err := NewVFS2File(t, ep2, stype) + if err != nil { + s1.DecRef() + ep2.Close() + return nil, nil, err + } + + return s1, s2, nil +} diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3e90dc4ed..2f04bf882 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -16,6 +16,7 @@ package vfs import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" ) // GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and @@ -44,6 +45,10 @@ type MknodOptions struct { // DevMinor are the major and minor device numbers for the created device. DevMajor uint32 DevMinor uint32 + + // Endpoint is the endpoint to bind to the created file, if a socket file is + // being created for bind(2) on a Unix domain socket. + Endpoint transport.BoundEndpoint } // MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC. -- cgit v1.2.3 From d5ddb5365086b13c0688c40fc74fa4cc4c5528db Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Tue, 7 Apr 2020 14:32:24 -0700 Subject: Remove out-of-date TODOs. We already have network namespace for netstack. PiperOrigin-RevId: 305341954 --- pkg/sentry/socket/netstack/provider.go | 8 -------- 1 file changed, 8 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index eb090e79b..c3f04b613 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: - // TODO(b/142504697): "In order to create a raw socket, a - // process must have the CAP_NET_RAW capability in the user - // namespace that governs its network namespace." - raw(7) - // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { @@ -141,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* } func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // TODO(b/142504697): "In order to create a packet socket, a process - // must have the CAP_NET_RAW capability in the user namespace that - // governs its network namespace." - packet(7) - // Packet sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(t) if !creds.HasCapability(linux.CAP_NET_RAW) { -- cgit v1.2.3 From 7928aa345e334f2c68f8f03b71d8cabe79e8db7e Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Thu, 9 Apr 2020 09:30:39 -0700 Subject: Convert int and bool socket options to use GetSockOptInt and GetSockOptBool PiperOrigin-RevId: 305699233 --- pkg/sentry/socket/netstack/netstack.go | 155 +++++------- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/unix.go | 50 ++-- pkg/tcpip/stack/transport_demuxer_test.go | 35 ++- pkg/tcpip/tcpip.go | 145 ++++++----- pkg/tcpip/transport/icmp/endpoint.go | 47 ++-- pkg/tcpip/transport/raw/endpoint.go | 18 +- pkg/tcpip/transport/tcp/endpoint.go | 387 +++++++++++++----------------- pkg/tcpip/transport/tcp/tcp_test.go | 110 +++++---- pkg/tcpip/transport/udp/endpoint.go | 231 +++++++++--------- pkg/tcpip/transport/udp/udp_test.go | 60 ++--- 11 files changed, 583 insertions(+), 656 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5d0085462..20e3fa0d2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -300,7 +300,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -965,6 +965,13 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. @@ -998,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1042,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1089,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1156,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v == 0 { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(!v), nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1328,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1342,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if outLen == 0 { return make([]byte, 0), nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1365,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIPv6(t, name) @@ -1386,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1403,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1429,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(v), nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { return []byte(nil), nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { @@ -1462,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1477,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIP(t, name) @@ -1592,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1600,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1628,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1636,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1644,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1716,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1728,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1736,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1744,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1855,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) case linux.IPV6_RECVTCLASS: v, err := parseIntOrChar(optVal) @@ -1940,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1987,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2008,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -2018,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) case linux.IP_RECVTOS: v, err := parseIntOrChar(optVal) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 74bcd6300..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2ef654235..1f3880cc5 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -838,24 +839,45 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil - } return nil } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + case tcpip.PasscredOption: + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } } func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { @@ -914,29 +936,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index c65b0c632..2474a7db3 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -206,7 +206,7 @@ func TestTransportDemuxerRegister(t *testing.T) { // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int + reuse bool bindToDevice tcpip.NICID } for _, test := range []struct { @@ -221,11 +221,11 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -236,9 +236,9 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {reuse: 0, bindToDevice: 1}, - {reuse: 0, bindToDevice: 2}, - {reuse: 0, bindToDevice: 3}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -253,12 +253,12 @@ func TestBindToDeviceDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -309,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2ef3271f1..aec7126ff 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -520,34 +520,90 @@ type WriteOptions struct { type SockOptBool int const ( + // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether + // datagram sockets are allowed to send packets to a broadcast address. + BroadcastOption SockOptBool = iota + + // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be + // held until segments are full by the TCP transport protocol. + CorkOption + + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + + // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether + // TCP keepalive is enabled for this socket. + KeepaliveEnabledOption + + // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether + // multicast packets sent over a non-loopback interface will be looped back. + MulticastLoopOption + + // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether + // SCM_CREDENTIALS socket control messages are enabled. + // + // Only supported on Unix sockets. + PasscredOption + + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + QuickAckOption + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the // IPV6_TCLASS ancillary message is passed with incoming packets. - ReceiveTClassOption SockOptBool = iota + ReceiveTClassOption // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. ReceiveTOSOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption - // TODO(b/146901447): convert existing bool socket options to be handled via - // Get/SetSockOptBool + // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() + // should allow reuse of local address. + ReuseAddressOption + + // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets + // to be bound to an identical socket address. + ReusePortOption + + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( + // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number + // of un-ACKed TCP keepalives that will be sent before the connection is + // closed. + KeepaliveCountOption SockOptInt = iota + + // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv4 packets from the endpoint. + IPv4TOSOption + + // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv6 packets from the endpoint. + IPv6TrafficClassOption + + // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current + // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + MaxSegOption + + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default + // TTL value for multicast messages. The default is 1. + MulticastTTLOption + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOptInt = iota + ReceiveQueueSizeOption // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -561,44 +617,21 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. - DelayOption - - // TODO(b/137664753): convert all int socket options to be handled via - // GetSockOptInt. + // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop + // limit value for unicast messages. The default is protocol specific. + // + // A zero value indicates the default. + TTLOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} -// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be -// held until segments are full by the TCP transport protocol. -type CorkOption int - -// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() -// should allow reuse of local address. -type ReuseAddressOption int - -// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets -// to be bound to an identical socket address. -type ReusePortOption int - // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID -// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. -type QuickAckOption int - -// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether -// SCM_CREDENTIALS socket control messages are enabled. -// -// Only supported on Unix sockets. -type PasscredOption int - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -607,10 +640,6 @@ type TCPInfoOption struct { RTTVar time.Duration } -// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP keepalive is enabled for this socket. -type KeepaliveEnabledOption int - // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. @@ -620,11 +649,6 @@ type KeepaliveIdleOption time.Duration // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration -// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number -// of un-ACKed TCP keepalives that will be sent before the connection is -// closed. -type KeepaliveCountOption int - // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. @@ -638,20 +662,9 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive // buffer moderation. type ModerateReceiveBufferOption bool -// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current -// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. -type MaxSegOption int - -// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop -// limit value for unicast messages. The default is protocol specific. -// -// A zero value indicates the default. -type TTLOption uint8 - // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -668,10 +681,6 @@ type TCPTimeWaitTimeoutOption time.Duration // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration -// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default -// TTL value for multicast messages. The default is 1. -type MulticastTTLOption uint8 - // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -679,10 +688,6 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether -// multicast packets sent over a non-loopback interface will be looped back. -type MulticastLoopOption bool - // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { @@ -705,22 +710,10 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether -// datagram sockets are allowed to send packets to a broadcast address. -type BroadcastOption int - // DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify // a default TTL. type DefaultTTLOption uint8 -// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv4 packets from the endpoint. -type IPv4TOSOption uint8 - -// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv6 packets from the endpoint. -type IPv6TrafficClassOption uint8 - // IPPacketInfo is the message struture for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b007302fb..3a133eef9 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -348,29 +348,37 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(o) - e.mu.Unlock() - } - return nil } // SetSockOptBool sets a socket option. Currently not supported. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return nil + return tcpip.ErrUnknownProtocolOption } // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + default: + return tcpip.ErrUnknownProtocolOption + } return nil } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -397,26 +405,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.TTLOption: - e.rcvMu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.rcvMu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 337bc1c71..eee754a5a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -533,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -548,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -576,9 +578,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption } - - return -1, tcpip.ErrUnknownProtocolOption } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9b123e968..a8d443f73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -821,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptInt(tcpip.DelayOption, 1) + e.SetSockOptBool(tcpip.DelayOption, true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -1409,10 +1409,60 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - e.LockUser() - defer e.UnlockUser() - switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.reuseAddr = v + e.UnlockUser() + return nil + + case tcpip.ReusePortOption: + e.LockUser() + e.reusePort = v + e.UnlockUser() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1424,7 +1474,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + e.LockUser() e.v6only = v + e.UnlockUser() + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -1432,7 +1486,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1483,7 +1570,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max @@ -1502,52 +1588,21 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.sndBufMu.Lock() e.sndBufSize = size e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() default: - return nil + return tcpip.ErrUnknownProtocolOption } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.LockUser() - e.reuseAddr = v != 0 - e.UnlockUser() - return nil - - case tcpip.ReusePortOption: - e.LockUser() - e.reusePort = v != 0 - e.UnlockUser() - return nil - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -1556,72 +1611,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.LockUser() e.bindToDevice = id e.UnlockUser() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.LockUser() - e.userMSS = uint16(userMSS) - e.UnlockUser() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.TTLOption: - e.LockUser() - e.ttl = uint8(v) - e.UnlockUser() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. case tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(v) e.UnlockUser() - return nil - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v != 0 - e.UnlockUser() - return nil case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1652,22 +1661,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - - case tcpip.IPv6TrafficClassOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - case tcpip.TCPLingerTimeoutOption: e.LockUser() if v < 0 { @@ -1688,7 +1681,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.tcpLingerTimeout = time.Duration(v) e.UnlockUser() - return nil case tcpip.TCPDeferAcceptOption: e.LockUser() @@ -1697,11 +1689,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.deferAccept = time.Duration(v) e.UnlockUser() - return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. @@ -1723,6 +1715,43 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.reuseAddr + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.reusePort + e.UnlockUser() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1734,14 +1763,41 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.UnlockUser() return v, nil - } - return false, tcpip.ErrUnknownProtocolOption + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1757,12 +1813,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1779,61 +1834,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err - case *tcpip.MaxSegOption: - // This is just stubbed out. Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.LockUser() - v := e.reuseAddr - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.LockUser() - v := e.reusePort - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.UnlockUser() - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.TTLOption: - e.LockUser() - *o = tcpip.TTLOption(e.ttl) - e.UnlockUser() - return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} @@ -1846,92 +1850,45 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil case *tcpip.TCPUserTimeoutOption: e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - return nil case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc e.UnlockUser() - return nil - - case *tcpip.IPv4TOSOption: - e.LockUser() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.UnlockUser() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.LockUser() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.UnlockUser() - return nil case *tcpip.TCPLingerTimeoutOption: e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) e.UnlockUser() - return nil case *tcpip.TCPDeferAcceptOption: e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // checkV4MappedLocked determines the effective network protocol and converts diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ce3df7478..32d0af6c4 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -728,7 +728,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize tests := []struct { name string - setMSS uint16 + setMSS int expMSS uint16 }{ { @@ -756,15 +756,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { c.Create(-1) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -818,15 +817,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { c.CreateV6Endpoint(true) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -1077,17 +1075,17 @@ func TestTOSV4(t *testing.T) { c.EP = ep const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) } - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) } testV4Connect(t, c, checker.TOS(tos, 0)) @@ -1125,17 +1123,17 @@ func TestTrafficClassV6(t *testing.T) { c.CreateV6Endpoint(false) const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) } - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) } // Test the connection request. @@ -1711,7 +1709,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1984,7 +1982,7 @@ func TestScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2057,7 +2055,7 @@ func TestNonScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2221,10 +2219,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) + ep.SetSockOptBool(tcpip.CorkOption, true) }, func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) + ep.SetSockOptBool(tcpip.CorkOption, false) }, }, } @@ -2316,7 +2314,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2364,7 +2362,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2397,7 +2395,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptInt(tcpip.DelayOption, 0) + c.EP.SetSockOptBool(tcpip.DelayOption, false) // Check that data is received. second := c.GetPacket() @@ -2434,8 +2432,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, } for _, test := range tests { @@ -2576,12 +2574,12 @@ func TestSetTTL(t *testing.T) { t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("Unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -2621,7 +2619,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2765,7 +2763,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { const rcvBufferSize = 0x20000 const wndScale = 2 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } // Start connection attempt. @@ -3882,26 +3880,26 @@ func TestMinMaxBufferSizes(t *testing.T) { // Set values below the min. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) @@ -4147,11 +4145,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { case "ipv4": case "ipv6": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) } case "dual": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -4477,8 +4475,8 @@ func TestKeepalive(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5770,14 +5768,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { func TestDelayEnabled(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. + checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { delayEnabled tcp.DelayEnabled - wantDelayOption int + wantDelayOption bool }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, } { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5788,7 +5786,7 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() var gotDelayEnabled tcp.DelayEnabled @@ -5803,12 +5801,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) } if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } } @@ -6620,8 +6618,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // Set userTimeout to be the duration for 3 keepalive probes. userTimeout := 30 * time.Millisecond diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 120d3baa3..492cc1fcb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -501,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v e.mu.Unlock() - return nil case tcpip.ReceiveTClassOption: // We only support this option on v6 endpoints. @@ -516,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Lock() e.receiveTClass = v e.mu.Unlock() - return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v + e.mu.Unlock() case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -533,13 +553,8 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v - return nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - return nil + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -547,22 +562,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} + switch opt { + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) e.mu.Unlock() - case tcpip.MulticastTTLOption: + case tcpip.IPv4TOSOption: e.mu.Lock() - e.multicastTTL = uint8(v) + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) e.mu.Unlock() + case tcpip.ReceiveBufferSizeOption: + case tcpip.SendBufferSizeOption: + + default: + return tcpip.ErrUnknownProtocolOption + } + + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -686,16 +719,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = bool(v) - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -704,26 +727,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() - return nil - - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - - return nil - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil } return nil } @@ -731,6 +734,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -748,6 +766,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + return false, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -760,19 +794,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil + default: + return false, tcpip.ErrUnknownProtocolOption } - - return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -794,29 +841,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := e.rcvBufSizeMax e.rcvMu.Unlock() return v, nil - } - return -1, tcpip.ErrUnknownProtocolOption + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil - - case *tcpip.MulticastTTLOption: - e.mu.Lock() - *o = tcpip.MulticastTTLOption(e.multicastTTL) - e.mu.Unlock() - return nil - case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -824,67 +864,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr, } e.mu.Unlock() - return nil - - case *tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - - *o = tcpip.MulticastLoopOption(v) - return nil - - case *tcpip.ReuseAddressOption: - *o = 0 - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.BindToDeviceOption: e.mu.RLock() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0905726c1..b3ee688b7 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } } @@ -1271,8 +1271,8 @@ func TestTTL(t *testing.T) { c.createEndpointForFlow(flow) const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) } var wantTTL uint8 @@ -1311,8 +1311,8 @@ func TestSetTTL(t *testing.T) { c.createEndpointForFlow(flow) - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } var p stack.NetworkProtocol @@ -1346,25 +1346,26 @@ func TestSetTOS(t *testing.T) { c.createEndpointForFlow(flow) const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) + if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tos { + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1381,25 +1382,26 @@ func TestSetTClass(t *testing.T) { c.createEndpointForFlow(flow) const tClass = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) + if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { + c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tClass); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tClass { + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) } // The header getter for TClass is called TOS, so use that checker. @@ -1430,7 +1432,7 @@ func TestReceiveTosTClass(t *testing.T) { // Verify that setting and reading the option works. v, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } // Test for expected default value. if v != false { @@ -1444,7 +1446,7 @@ func TestReceiveTosTClass(t *testing.T) { got, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } if got != want { -- cgit v1.2.3 From c9195349c9ac24ccb538e92b308225dfa4184c42 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 9 Apr 2020 17:59:30 -0700 Subject: Replace type assertion with TaskFromContext. This should fix panic at aio callback. PiperOrigin-RevId: 305798549 --- pkg/sentry/socket/netstack/netstack.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 20e3fa0d2..7ac38764d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -535,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } @@ -608,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } -- cgit v1.2.3 From 7c0f3bc8576addbec001095d754a756691d26df3 Mon Sep 17 00:00:00 2001 From: Dave Bailey Date: Tue, 21 Apr 2020 09:34:42 -0700 Subject: Sentry metrics updates. Sentry metrics with nanoseconds units are labeled as such, and non-cumulative sentry metrics are supported. PiperOrigin-RevId: 307621080 --- pkg/metric/metric.go | 41 +++++++++++++++++----------------- pkg/metric/metric.proto | 10 ++++++++- pkg/metric/metric_test.go | 22 +++++++++++------- pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/gofer/file.go | 4 ++-- pkg/sentry/fs/tmpfs/inode_file.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 14 ++++++++---- 7 files changed, 58 insertions(+), 37 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 895253625..64aa365ce 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -39,16 +39,11 @@ var ( // Uint64Metric encapsulates a uint64 that represents some kind of metric to be // monitored. // -// All metrics must be cumulative, meaning that their values will only increase -// over time. -// // Metrics are not saved across save/restore and thus reset to zero on restore. // -// TODO(b/67298402): Support non-cumulative metrics. // TODO(b/67298427): Support metric fields. type Uint64Metric struct { - // value is the actual value of the metric. It must be accessed - // atomically. + // value is the actual value of the metric. It must be accessed atomically. value uint64 } @@ -110,13 +105,10 @@ type customUint64Metric struct { // Register must only be called at init and will return and error if called // after Initialized. // -// All metrics must be cumulative, meaning that the return values of value must -// only increase over time. -// // Preconditions: // * name must be globally unique. // * Initialize/Disable have not been called. -func RegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) error { +func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error { if initialized { return ErrInitializationDone } @@ -129,9 +121,10 @@ func RegisterCustomUint64Metric(name string, sync bool, description string, valu metadata: &pb.MetricMetadata{ Name: name, Description: description, - Cumulative: true, + Cumulative: cumulative, Sync: sync, - Type: pb.MetricMetadata_UINT64, + Type: pb.MetricMetadata_TYPE_UINT64, + Units: units, }, value: value, } @@ -140,24 +133,32 @@ func RegisterCustomUint64Metric(name string, sync bool, description string, valu // MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics // if it returns an error. -func MustRegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) { - if err := RegisterCustomUint64Metric(name, sync, description, value); err != nil { +func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) { + if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil { panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) } } -// NewUint64Metric creates and registers a new metric with the given name. +// NewUint64Metric creates and registers a new cumulative metric with the given name. // // Metrics must be statically defined (i.e., at init). -func NewUint64Metric(name string, sync bool, description string) (*Uint64Metric, error) { +func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) { var m Uint64Metric - return &m, RegisterCustomUint64Metric(name, sync, description, m.Value) + return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value) } -// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an -// error. +// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an error. func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric { - m, err := NewUint64Metric(name, sync, description) + m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description) + if err != nil { + panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + } + return m +} + +// MustCreateNewUint64NanosecondsMetric calls NewUint64Metric and panics if it returns an error. +func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric { + m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description) if err != nil { panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) } diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto index a2c2bd1ba..3cc89047d 100644 --- a/pkg/metric/metric.proto +++ b/pkg/metric/metric.proto @@ -36,10 +36,18 @@ message MetricMetadata { // the monitoring system. bool sync = 4; - enum Type { UINT64 = 0; } + enum Type { TYPE_UINT64 = 0; } // type is the type of the metric value. Type type = 5; + + enum Units { + UNITS_NONE = 0; + UNITS_NANOSECONDS = 1; + } + + // units is the units of the metric value. + Units units = 6; } // MetricRegistration contains the metadata for all metrics that will be in diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index 34969385a..c425ea532 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -66,12 +66,12 @@ const ( func TestInitialize(t *testing.T) { defer reset() - _, err := NewUint64Metric("/foo", false, fooDescription) + _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } - _, err = NewUint64Metric("/bar", true, barDescription) + _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NANOSECONDS, barDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } @@ -94,8 +94,8 @@ func TestInitialize(t *testing.T) { foundFoo := false foundBar := false for _, m := range mr.Metrics { - if m.Type != pb.MetricMetadata_UINT64 { - t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_UINT64) + if m.Type != pb.MetricMetadata_TYPE_UINT64 { + t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64) } if !m.Cumulative { t.Errorf("Metadata %+v Cumulative got false want true", m) @@ -110,6 +110,9 @@ func TestInitialize(t *testing.T) { if m.Sync { t.Errorf("/foo %+v Sync got true want false", m) } + if m.Units != pb.MetricMetadata_UNITS_NONE { + t.Errorf("/foo %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NONE) + } case "/bar": foundBar = true if m.Description != barDescription { @@ -118,6 +121,9 @@ func TestInitialize(t *testing.T) { if !m.Sync { t.Errorf("/bar %+v Sync got true want false", m) } + if m.Units != pb.MetricMetadata_UNITS_NANOSECONDS { + t.Errorf("/bar %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NANOSECONDS) + } } } @@ -132,12 +138,12 @@ func TestInitialize(t *testing.T) { func TestDisable(t *testing.T) { defer reset() - _, err := NewUint64Metric("/foo", false, fooDescription) + _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } - _, err = NewUint64Metric("/bar", true, barDescription) + _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } @@ -161,12 +167,12 @@ func TestDisable(t *testing.T) { func TestEmitMetricUpdate(t *testing.T) { defer reset() - foo, err := NewUint64Metric("/foo", false, fooDescription) + foo, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } - _, err = NewUint64Metric("/bar", true, barDescription) + _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription) if err != nil { t.Fatalf("NewUint64Metric got err %v want nil", err) } diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 78100e448..846252c89 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -44,7 +44,7 @@ var ( RecordWaitTime = false reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.") - readWait = metric.MustCreateNewUint64Metric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") + readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") ) // IncrementWait increments the given wait time metric, if enabled. diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 23296f246..b2fcab127 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -37,9 +37,9 @@ var ( opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.") opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.") reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") - readWait9P = metric.MustCreateNewUint64Metric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") + readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.") - readWaitHost = metric.MustCreateNewUint64Metric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") + readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") ) // fileOperations implements fs.FileOperations for a remote file system. diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index 25abbc151..1dc75291d 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -39,7 +39,7 @@ var ( opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.") opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.") reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.") - readWait = metric.MustCreateNewUint64Metric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") + readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") ) // fileInodeOperations implements fs.InodeOperations for a regular tmpfs file. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 7ac38764d..d5879c10f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -63,7 +63,13 @@ import ( func mustCreateMetric(name, description string) *tcpip.StatCounter { var cm tcpip.StatCounter - metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value) + metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value) + return &cm +} + +func mustCreateGauge(name, description string) *tcpip.StatCounter { + var cm tcpip.StatCounter + metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value) return &cm } @@ -151,10 +157,10 @@ var Metrics = tcpip.Stats{ TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), - CurrentConnected: mustCreateMetric("/netstack/tcp/current_open", "Number of connections that are in connected state."), + CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), + CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), - EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."), EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), -- cgit v1.2.3 From 82bae30ceea0984c94af3085866b58ec9e69ea67 Mon Sep 17 00:00:00 2001 From: Dean Deng Date: Fri, 1 May 2020 12:53:15 -0700 Subject: Port netstack, hostinet, and netlink sockets to VFS2. All three follow the same pattern: 1. Refactor VFS1 sockets into socketOpsCommon, so that most of the methods can be shared with VFS2. 2. Create a FileDescriptionImpl with the corresponding socket operations, rewriting the few that cannot be shared with VFS1. 3. Set up a VFS2 socket provider that creates a socket by setting up a dentry in the global Kernel.socketMount and connecting it with a new FileDescription. This mostly completes the work for porting sockets to VFS2, and many syscall tests can be enabled as a result. There are several networking-related syscall tests that are still not passing: 1. net gofer tests 2. socketpair gofer tests 2. sendfile tests (splice is not implemented in VFS2 yet) Updates #1478, #1484, #1485 PiperOrigin-RevId: 309457331 --- pkg/sentry/fsimpl/gofer/socket.go | 9 +- pkg/sentry/fsimpl/host/host.go | 2 +- pkg/sentry/fsimpl/host/socket.go | 53 +++-- pkg/sentry/fsimpl/sockfs/sockfs.go | 5 + pkg/sentry/socket/hostinet/BUILD | 5 + pkg/sentry/socket/hostinet/socket.go | 105 ++++++--- pkg/sentry/socket/hostinet/socket_unsafe.go | 10 +- pkg/sentry/socket/hostinet/socket_vfs2.go | 183 ++++++++++++++++ pkg/sentry/socket/netlink/BUILD | 5 + pkg/sentry/socket/netlink/provider.go | 5 + pkg/sentry/socket/netlink/provider_vfs2.go | 71 ++++++ pkg/sentry/socket/netlink/socket.go | 72 +++--- pkg/sentry/socket/netlink/socket_vfs2.go | 138 ++++++++++++ pkg/sentry/socket/netstack/BUILD | 5 + pkg/sentry/socket/netstack/netstack.go | 72 +++--- pkg/sentry/socket/netstack/netstack_vfs2.go | 327 ++++++++++++++++++++++++++++ pkg/sentry/socket/netstack/provider.go | 4 + pkg/sentry/socket/netstack/provider_vfs2.go | 141 ++++++++++++ pkg/sentry/socket/unix/unix_vfs2.go | 4 +- 19 files changed, 1090 insertions(+), 126 deletions(-) create mode 100644 pkg/sentry/socket/hostinet/socket_vfs2.go create mode 100644 pkg/sentry/socket/netlink/provider_vfs2.go create mode 100644 pkg/sentry/socket/netlink/socket_vfs2.go create mode 100644 pkg/sentry/socket/netstack/netstack_vfs2.go create mode 100644 pkg/sentry/socket/netstack/provider_vfs2.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 73835df91..d6dbe9092 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -87,7 +87,9 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec returnConnect(c, c) ce.Unlock() - c.Init() + if err := c.Init(); err != nil { + return syserr.FromError(err) + } return nil } @@ -99,7 +101,10 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect if err != nil { return nil, err } - c.Init() + + if err := c.Init(); err != nil { + return nil, syserr.FromError(err) + } // We don't need the receiver. c.CloseRecv() diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 33bd205af..34fbc69af 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -411,7 +411,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount) (*vfs.F return nil, syserror.ENOTTY } - ep, err := newEndpoint(ctx, i.hostFD) + ep, err := newEndpoint(ctx, i.hostFD, &i.queue) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 39dd624a0..38f1fbfba 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -34,17 +34,16 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// Create a new host-backed endpoint from the given fd. -func newEndpoint(ctx context.Context, hostFD int) (transport.Endpoint, error) { +// Create a new host-backed endpoint from the given fd and its corresponding +// notification queue. +func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) { // Set up an external transport.Endpoint using the host fd. addr := fmt.Sprintf("hostfd:[%d]", hostFD) - var q waiter.Queue - e, err := NewConnectedEndpoint(ctx, hostFD, &q, addr, true /* saveable */) + e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */) if err != nil { return nil, err.ToError() } - e.Init() - ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) + ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), queue, e, e) return ep, nil } @@ -77,8 +76,6 @@ type ConnectedEndpoint struct { // addr is the address at which this endpoint is bound. addr string - queue *waiter.Queue - // sndbuf is the size of the send buffer. // // N.B. When this is smaller than the host size, we present it via @@ -134,11 +131,10 @@ func (c *ConnectedEndpoint) init() *syserr.Error { // The caller is responsible for calling Init(). Additionaly, Release needs to // be called twice because ConnectedEndpoint is both a transport.Receiver and // transport.ConnectedEndpoint. -func NewConnectedEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) { +func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) { e := ConnectedEndpoint{ - fd: hostFD, - addr: addr, - queue: queue, + fd: hostFD, + addr: addr, } if err := e.init(); err != nil { @@ -151,13 +147,6 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, return &e, nil } -// Init will do the initialization required without holding other locks. -func (c *ConnectedEndpoint) Init() { - if err := fdnotifier.AddFD(int32(c.fd), c.queue); err != nil { - panic(err) - } -} - // Send implements transport.ConnectedEndpoint.Send. func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() @@ -332,7 +321,6 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { } func (c *ConnectedEndpoint) destroyLocked() { - fdnotifier.RemoveFD(int32(c.fd)) c.fd = -1 } @@ -350,14 +338,20 @@ func (c *ConnectedEndpoint) Release() { func (c *ConnectedEndpoint) CloseUnread() {} // SCMConnectedEndpoint represents an endpoint backed by a host fd that was -// passed through a gofer Unix socket. It is almost the same as -// ConnectedEndpoint, with the following differences: +// passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the +// following differences: // - SCMConnectedEndpoint is not saveable, because the host cannot guarantee // the same descriptor number across S/R. -// - SCMConnectedEndpoint holds ownership of its fd and is responsible for -// closing it. +// - SCMConnectedEndpoint holds ownership of its fd and notification queue. type SCMConnectedEndpoint struct { ConnectedEndpoint + + queue *waiter.Queue +} + +// Init will do the initialization required without holding other locks. +func (e *SCMConnectedEndpoint) Init() error { + return fdnotifier.AddFD(int32(e.fd), e.queue) } // Release implements transport.ConnectedEndpoint.Release and @@ -368,6 +362,7 @@ func (e *SCMConnectedEndpoint) Release() { if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) } + fdnotifier.RemoveFD(int32(e.fd)) e.destroyLocked() e.mu.Unlock() }) @@ -380,11 +375,13 @@ func (e *SCMConnectedEndpoint) Release() { // be called twice because ConnectedEndpoint is both a transport.Receiver and // transport.ConnectedEndpoint. func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) { - e := SCMConnectedEndpoint{ConnectedEndpoint{ - fd: hostFD, - addr: addr, + e := SCMConnectedEndpoint{ + ConnectedEndpoint: ConnectedEndpoint{ + fd: hostFD, + addr: addr, + }, queue: queue, - }} + } if err := e.init(); err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index 271134af8..dac2389fc 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -74,6 +74,11 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr } // NewDentry constructs and returns a sockfs dentry. +// +// TODO(gvisor.dev/issue/1476): Currently, we are using +// sockfs.filesystem.NextIno() to get inode numbers. We should use +// device-specific numbers, so that we are not using the same generator for +// netstack, unix, etc. func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry { // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 023bad156..deedd35f7 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -10,6 +10,7 @@ go_library( "save_restore.go", "socket.go", "socket_unsafe.go", + "socket_vfs2.go", "sockopt_impl.go", "stack.go", ], @@ -25,11 +26,15 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/fsimpl/sockfs", + "//pkg/sentry/hostfd", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", "//pkg/sentry/socket/control", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 22f78d2e2..b49433326 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -49,6 +50,8 @@ const ( maxControlLen = 1024 ) +// LINT.IfChange + // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { @@ -59,23 +62,37 @@ type socketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout family int // Read-only. stype linux.SockType // Read-only. protocol int // Read-only. - fd int // must be O_NONBLOCK queue waiter.Queue + + // fd is the host socket fd. It must have O_NONBLOCK, so that operations + // will return EWOULDBLOCK instead of blocking on the host. This allows us to + // handle blocking behavior independently in the sentry. + fd int } var _ = socket.Socket(&socketOperations{}) func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { s := &socketOperations{ - family: family, - stype: stype, - protocol: protocol, - fd: fd, + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) @@ -86,28 +103,33 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc } // Release implements fs.FileOperations.Release. -func (s *socketOperations) Release() { +func (s *socketOpsCommon) Release() { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. -func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(s.fd), mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) fdnotifier.UpdateFD(int32(s.fd)) } +// Ioctl implements fs.FileOperations.Ioctl. +func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, io, args) +} + // Read implements fs.FileOperations.Read. func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -155,7 +177,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } // Connect implements socket.Socket.Connect. -func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -195,7 +217,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { var peerAddr linux.SockAddr var peerAddrBuf []byte var peerAddrlen uint32 @@ -209,7 +231,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } // Conservatively ignore all flags specified by the application and add - // SOCK_NONBLOCK since socketOperations requires it. + // SOCK_NONBLOCK since socketOpsCommon requires it. fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) if blocking { var ch chan struct{} @@ -235,23 +257,41 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) - if err != nil { - syscall.Close(fd) - return 0, nil, 0, err - } - defer f.DecRef() + var ( + kfd int32 + kerr error + ) + if kernel.VFS2Enabled { + f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK)) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef() - kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{ - CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, - }) - t.Kernel().RecordSocket(f) + kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocketVFS2(f) + } else { + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef() + + kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocket(f) + } return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } // Bind implements socket.Socket.Bind. -func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -264,12 +304,12 @@ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(syscall.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { switch how { case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR: return syserr.FromError(syscall.Shutdown(s.fd, how)) @@ -279,7 +319,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -328,7 +368,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt } // SetSockOpt implements socket.Socket.SetSockOpt. -func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { +func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { // Whitelist options and constrain option length. optlen := setSockOptLen(t, level, name) switch level { @@ -374,7 +414,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary @@ -496,7 +536,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Whitelist flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument @@ -585,7 +625,7 @@ func translateIOSyscallError(err error) error { } // State implements socket.Socket.State. -func (s *socketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { info := linux.TCPInfo{} buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) if err != nil { @@ -607,7 +647,7 @@ func (s *socketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.stype, s.protocol } @@ -663,8 +703,11 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int return nil, nil, nil } +// LINT.ThenChange(./socket_vfs2.go) + func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{}) } } diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index cd67234d2..3f420c2ec 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/syserr" @@ -54,12 +53,11 @@ func writev(fd int, srcs []syscall.Iovec) (uint64, error) { return uint64(n), nil } -// Ioctl implements fs.FileOperations.Ioctl. -func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := uintptr(args[1].Int()); cmd { case syscall.TIOCINQ, syscall.TIOCOUTQ: var val int32 - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { return 0, translateIOSyscallError(errno) } var buf [4]byte @@ -93,7 +91,7 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) @@ -104,7 +102,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go new file mode 100644 index 000000000..b03ca2f26 --- /dev/null +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +type socketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + // TODO(gvisor.dev/issue/1484): VFS1 stores internal metadata for hostinet. + // We should perhaps rely on the host, much like in hostfs. + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&socketVFS2{}) + +func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + + s := &socketVFS2{ + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, + } + if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { + return nil, syserr.FromError(err) + } + vfsfd := &s.vfsfd + if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *socketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, uio, args) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := dst.CopyOutFrom(ctx, reader) + hostfd.PutReadWriterAt(reader) + return int64(n), err +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := src.CopyInTo(ctx, writer) + hostfd.PutReadWriterAt(writer) + return int64(n), err +} + +type socketProviderVFS2 struct { + family int +} + +// Socket implements socket.ProviderVFS2.Socket. +func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check that we are using the host network stack. + stack := t.NetworkContext() + if stack == nil { + return nil, nil + } + if _, ok := stack.(*Stack); !ok { + return nil, nil + } + + // Only accept TCP and UDP. + stype := stypeflags & linux.SOCK_TYPE_MASK + switch stype { + case syscall.SOCK_STREAM: + switch protocol { + case 0, syscall.IPPROTO_TCP: + // ok + default: + return nil, nil + } + case syscall.SOCK_DGRAM: + switch protocol { + case 0, syscall.IPPROTO_UDP: + // ok + default: + return nil, nil + } + default: + return nil, nil + } + + // Conservatively ignore all flags specified by the application and add + // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 + // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return nil, syserr.FromError(err) + } + return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK)) +} + +// Pair implements socket.Provider.Pair. +func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Not supported by AF_INET/AF_INET6. + return nil, nil, nil +} diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 1911cd9b8..09ca00a4a 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -7,7 +7,9 @@ go_library( srcs = [ "message.go", "provider.go", + "provider_vfs2.go", "socket.go", + "socket_vfs2.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -18,6 +20,8 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", @@ -25,6 +29,7 @@ go_library( "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index b0dc70e5c..0d45e5053 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -67,6 +67,8 @@ func RegisterProvider(protocol int, provider Provider) { protocols[protocol] = provider } +// LINT.IfChange + // socketProvider implements socket.Provider. type socketProvider struct { } @@ -105,7 +107,10 @@ func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.Fi return nil, nil, syserr.ErrNotSupported } +// LINT.ThenChange(./provider_vfs2.go) + // init registers the socket provider. func init() { socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{}) + socket.RegisterProviderVFS2(linux.AF_NETLINK, &socketProviderVFS2{}) } diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go new file mode 100644 index 000000000..dcd92b5cd --- /dev/null +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -0,0 +1,71 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" +) + +// socketProviderVFS2 implements socket.Provider. +type socketProviderVFS2 struct { +} + +// Socket implements socket.Provider.Socket. +func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Netlink sockets must be specified as datagram or raw, but they + // behave the same regardless of type. + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { + return nil, syserr.ErrSocketNotSupported + } + + provider, ok := protocols[protocol] + if !ok { + return nil, syserr.ErrProtocolNotSupported + } + + p, err := provider(t) + if err != nil { + return nil, err + } + + s, err := NewVFS2(t, stype, p) + if err != nil { + return nil, err + } + + vfsfd := &s.vfsfd + mnt := t.Kernel().SocketMount() + fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Pair implements socket.Provider.Pair by returning an error. +func (*socketProviderVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Netlink sockets never supports creating socket pairs. + return nil, nil, syserr.ErrNotSupported +} diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 2ca02567d..81f34c5a2 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -58,6 +58,8 @@ var errNoFilter = syserr.New("no filter attached", linux.ENOENT) // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() +// LINT.IfChange + // Socket is the base socket type for netlink sockets. // // This implementation only supports userspace sending and receiving messages @@ -74,6 +76,14 @@ type Socket struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout // ports provides netlink port allocation. @@ -140,17 +150,19 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke } return &Socket{ - ports: t.Kernel().NetlinkPorts(), - protocol: protocol, - skType: skType, - ep: ep, - connection: connection, - sendBufferSize: defaultSendBufferSize, + socketOpsCommon: socketOpsCommon{ + ports: t.Kernel().NetlinkPorts(), + protocol: protocol, + skType: skType, + ep: ep, + connection: connection, + sendBufferSize: defaultSendBufferSize, + }, }, nil } // Release implements fs.FileOperations.Release. -func (s *Socket) Release() { +func (s *socketOpsCommon) Release() { s.connection.Release() s.ep.Close() @@ -160,7 +172,7 @@ func (s *Socket) Release() { } // Readiness implements waiter.Waitable.Readiness. -func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { // ep holds messages to be read and thus handles EventIn readiness. ready := s.ep.Readiness(mask) @@ -174,18 +186,18 @@ func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask { } // EventRegister implements waiter.Waitable.EventRegister. -func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) // Writable readiness never changes, so no registration is needed. } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *Socket) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } // Passcred implements transport.Credentialer.Passcred. -func (s *Socket) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { s.mu.Lock() passcred := s.passcred s.mu.Unlock() @@ -193,7 +205,7 @@ func (s *Socket) Passcred() bool { } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *Socket) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { // This socket is connected to the kernel, which doesn't need creds. // // This is arbitrary, as ConnectedPasscred on this type has no callers. @@ -227,7 +239,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { // port of 0 defaults to the ThreadGroup ID. // // Preconditions: mu is held. -func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { +func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error { if s.bound { // Re-binding is only allowed if the port doesn't change. if port != s.portID { @@ -251,7 +263,7 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { } // Bind implements socket.Socket.Bind. -func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { a, err := ExtractSockAddr(sockaddr) if err != nil { return err @@ -269,7 +281,7 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Connect implements socket.Socket.Connect. -func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { a, err := ExtractSockAddr(sockaddr) if err != nil { return err @@ -300,25 +312,25 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } // Accept implements socket.Socket.Accept. -func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Netlink sockets never support accept. return 0, nil, 0, syserr.ErrNotSupported } // Listen implements socket.Socket.Listen. -func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // Netlink sockets never support listen. return syserr.ErrNotSupported } // Shutdown implements socket.Socket.Shutdown. -func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // Netlink sockets never support shutdown. return syserr.ErrNotSupported } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -369,7 +381,7 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. } // SetSockOpt implements socket.Socket.SetSockOpt. -func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { +func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: switch name { @@ -466,7 +478,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } // GetSockName implements socket.Socket.GetSockName. -func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -478,7 +490,7 @@ func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er } // GetPeerName implements socket.Socket.GetPeerName. -func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, // TODO(b/68878065): Support non-kernel peers. For now the peer @@ -489,7 +501,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { from := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, @@ -590,7 +602,7 @@ func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) var kernelCreds = &kernelSCM{} // sendResponse sends the response messages in ms back to userspace. -func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { +func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. bufs := make([][]byte, 0, len(ms.Messages)) for _, m := range ms.Messages { @@ -666,7 +678,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { // processMessages handles each message in buf, passing it to the protocol // handler for final handling. -func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { +func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error { for len(buf) > 0 { msg, rest, ok := ParseMessage(buf) if !ok { @@ -698,7 +710,7 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } // sendMsg is the core of message send, used for SendMsg and Write. -func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { dstPort := int32(0) if len(to) != 0 { @@ -745,7 +757,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, } // SendMsg implements socket.Socket.SendMsg. -func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { return s.sendMsg(t, src, to, flags, controlMessages) } @@ -756,11 +768,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, } // State implements socket.Socket.State. -func (s *Socket) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return linux.AF_NETLINK, s.skType, s.protocol.Protocol() } + +// LINT.ThenChange(./socket_vfs2.go) diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go new file mode 100644 index 000000000..b854bf990 --- /dev/null +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -0,0 +1,138 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/unix" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SocketVFS2 is the base VFS2 socket type for netlink sockets. +// +// This implementation only supports userspace sending and receiving messages +// to/from the kernel. +// +// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +var _ socket.SocketVFS2 = (*SocketVFS2)(nil) +var _ transport.Credentialer = (*SocketVFS2)(nil) + +// NewVFS2 creates a new SocketVFS2. +func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketVFS2, *syserr.Error) { + // Datagram endpoint used to buffer kernel -> user messages. + ep := transport.NewConnectionless(t) + + // Bind the endpoint for good measure so we can connect to it. The + // bound address will never be exposed. + if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { + ep.Close() + return nil, err + } + + // Create a connection from which the kernel can write messages. + connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) + if err != nil { + ep.Close() + return nil, err + } + + return &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ports: t.Kernel().NetlinkPorts(), + protocol: protocol, + skType: skType, + ep: ep, + connection: connection, + sendBufferSize: defaultSendBufferSize, + }, + }, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (*SocketVFS2) Ioctl(context.Context, usermem.IO, arch.SyscallArguments) (uintptr, error) { + // TODO(b/68878065): no ioctls supported. + return 0, syserror.ENOTTY +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &unix.EndpointReader{ + Endpoint: s.ep, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) + return int64(n), err.ToError() +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index cbf46b1e9..ccf9fcf5c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -7,7 +7,9 @@ go_library( srcs = [ "device.go", "netstack.go", + "netstack_vfs2.go", "provider.go", + "provider_vfs2.go", "save_restore.go", "stack.go", ], @@ -25,6 +27,8 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -32,6 +36,7 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", + "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d5879c10f..81053d8ef 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -252,6 +252,8 @@ type commonEndpoint interface { GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } +// LINT.IfChange + // SocketOperations encapsulates all the state needed to represent a network stack // endpoint in the kernel context. // @@ -263,6 +265,14 @@ type SocketOperations struct { fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout *waiter.Queue @@ -314,11 +324,13 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue dirent := socket.NewDirent(t, netstackDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ - Queue: queue, - family: family, - Endpoint: endpoint, - skType: skType, - protocol: protocol, + socketOpsCommon: socketOpsCommon{ + Queue: queue, + family: family, + Endpoint: endpoint, + skType: skType, + protocol: protocol, + }, }), nil } @@ -417,7 +429,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { } } -func (s *SocketOperations) isPacketBased() bool { +func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -425,7 +437,7 @@ func (s *SocketOperations) isPacketBased() bool { // empty. It assumes that the socket is locked. // // Precondition: s.readMu must be held. -func (s *SocketOperations) fetchReadView() *syserr.Error { +func (s *socketOpsCommon) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } @@ -446,7 +458,7 @@ func (s *SocketOperations) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release() { s.Endpoint.Close() } @@ -633,7 +645,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } // Readiness returns a mask of ready events for socket s. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) // Check our cached value iff the caller asked for readability and the @@ -647,7 +659,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { return r } -func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error { +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { if family == uint16(s.family) { return nil } @@ -670,7 +682,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error // represented by the empty string. // // TODO(gvisor.dev/issue/1556): remove this function. -func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { +func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" } @@ -679,7 +691,7 @@ func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpi // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err @@ -725,7 +737,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < 2 { return syserr.ErrInvalidArgument } @@ -771,13 +783,13 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -863,7 +875,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -2258,7 +2270,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2270,7 +2282,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2285,7 +2297,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // caller. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { +func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { var err *syserr.Error var copied int @@ -2337,7 +2349,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } -func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { +func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { if !s.sockOptInq { return } @@ -2352,7 +2364,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -2461,7 +2473,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } -func (s *SocketOperations) controlMessages() socket.ControlMessages { +func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, @@ -2480,7 +2492,7 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages { // successfully writing packet data out to userspace. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) updateTimestamp() { +func (s *socketOpsCommon) updateTimestamp() { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true @@ -2490,7 +2502,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -2558,7 +2570,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements the linux syscall sendmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Reject Unix control messages. if !controlMessages.Unix.Empty() { return 0, syserr.ErrInvalidArgument @@ -2634,6 +2646,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return s.socketOpsCommon.ioctl(ctx, io, args) +} + +func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. @@ -2973,7 +2989,7 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { if s.family != linux.AF_INET && s.family != linux.AF_INET6 { // States not implemented for this socket's family. return 0 @@ -3033,6 +3049,8 @@ func (s *SocketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.skType, s.protocol } + +// LINT.ThenChange(./netstack_vfs2.go) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go new file mode 100644 index 000000000..eec71035d --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -0,0 +1,327 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SocketVFS2 encapsulates all the state needed to represent a network stack +// endpoint in the kernel context. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +// NewVFS2 creates a new endpoint socket. +func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { + if skType == linux.SOCK_STREAM { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + } + + mnt := t.Kernel().SocketMount() + fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + + s := &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + Queue: queue, + family: family, + Endpoint: endpoint, + skType: skType, + protocol: protocol, + }, + } + vfsfd := &s.vfsfd + if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) + if err == syserr.ErrWouldBlock { + return int64(n), syserror.ErrWouldBlock + } + if err != nil { + return 0, err.ToError() + } + return int64(n), nil +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + f := &ioSequencePayload{ctx: ctx, src: src} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := kernel.TaskFromContext(ctx) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) + } + + if err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + + if int64(n) < src.NumBytes() { + return int64(n), syserror.ErrWouldBlock + } + + return int64(n), nil +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, wq, terr := s.Endpoint.Accept() + if terr != nil { + if terr != tcpip.ErrWouldBlock || !blocking { + return 0, nil, 0, syserr.TranslateNetstackError(terr) + } + + var err *syserr.Error + ep, wq, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef() + + if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil { + return 0, nil, 0, syserr.FromError(err) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer and write it to peer slice. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + + t.Kernel().RecordSocketVFS2(ns) + + return fd, addr, addrLen, syserr.FromError(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return s.socketOpsCommon.ioctl(ctx, uio, args) +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { + // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is + // implemented specifically for netstack.SocketVFS2 rather than + // commonEndpoint. commonEndpoint should be extended to support socket + // options where the implementation is not shared, as unix sockets need + // their own support for SO_TIMESTAMP. + if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptTimestamp { + val = 1 + } + return val, nil + } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return val, nil + } + + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) + if err != nil { + return nil, err + } + return info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return entries, nil + + } + } + + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is + // implemented specifically for netstack.SocketVFS2 rather than + // commonEndpoint. commonEndpoint should be extended to support socket + // options where the implementation is not shared, as unix sockets need + // their own support for SO_TIMESTAMP. + if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } + + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + } + } + + return SetSockOpt(t, s, s.Endpoint, level, name, optVal) +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index c3f04b613..ead3b2b79 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -33,6 +33,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + // provider is an inet socket provider. type provider struct { family int @@ -167,6 +169,8 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol return New(t, linux.AF_PACKET, stype, protocol, wq, ep) } +// LINT.ThenChange(./provider_vfs2.go) + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go new file mode 100644 index 000000000..2a01143f6 --- /dev/null +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -0,0 +1,141 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/waiter" +) + +// providerVFS2 is an inet socket provider. +type providerVFS2 struct { + family int + netProto tcpip.NetworkProtocolNumber +} + +// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET +// family. +func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Fail right away if we don't have a stack. + stack := t.NetworkContext() + if stack == nil { + // Don't propagate an error here. Instead, allow the socket + // code to continue searching for another provider. + return nil, nil + } + eps, ok := stack.(*Stack) + if !ok { + return nil, nil + } + + // Packet sockets are handled separately, since they are neither INET + // nor INET6 specific. + if p.family == linux.AF_PACKET { + return packetSocketVFS2(t, eps, stype, protocol) + } + + // Figure out the transport protocol. + transProto, associated, err := getTransportProtocol(t, stype, protocol) + if err != nil { + return nil, err + } + + // Create the endpoint. + var ep tcpip.Endpoint + var e *tcpip.Error + wq := &waiter.Queue{} + if stype == linux.SOCK_RAW { + ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) + } else { + ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } + } + if e != nil { + return nil, syserr.TranslateNetstackError(e) + } + + return NewVFS2(t, p.family, stype, int(transProto), wq, ep) +} + +func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return NewVFS2(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + +// Pair just returns nil sockets (not supported). +func (*providerVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + return nil, nil, nil +} + +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. +func init() { + // Providers backed by netstack. + p := []providerVFS2{ + { + family: linux.AF_INET, + netProto: ipv4.ProtocolNumber, + }, + { + family: linux.AF_INET6, + netProto: ipv6.ProtocolNumber, + }, + { + family: linux.AF_PACKET, + }, + } + + for i := range p { + socket.RegisterProviderVFS2(p[i].family, &p[i]) + } +} diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 23db93f33..5edc3cdf4 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -229,7 +229,7 @@ func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset i // Read implements vfs.FileDescriptionImpl. func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // All flags other than RWF_NOWAIT should be ignored. - // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. if opts.Flags != 0 { return 0, syserror.EOPNOTSUPP } @@ -254,7 +254,7 @@ func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset // Write implements vfs.FileDescriptionImpl. func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // All flags other than RWF_NOWAIT should be ignored. - // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. if opts.Flags != 0 { return 0, syserror.EOPNOTSUPP } -- cgit v1.2.3 From faf89dd31a44b8409b32919d7193834e194ecc56 Mon Sep 17 00:00:00 2001 From: Dean Deng Date: Tue, 5 May 2020 12:09:39 -0700 Subject: Update vfs2 socket TODOs. Three updates: - Mark all vfs2 socket syscalls as supported. - Use the same dev number and ino number generator for all types of sockets, unlike in VFS1. - Do not use host fd for hostinet metadata. Fixes #1476, #1478, #1484, 1485, #2017. PiperOrigin-RevId: 309994579 --- pkg/sentry/fsimpl/sockfs/sockfs.go | 9 +---- pkg/sentry/socket/hostinet/socket_vfs2.go | 7 ++-- pkg/sentry/socket/netstack/netstack_vfs2.go | 3 ++ pkg/sentry/socket/unix/unix_vfs2.go | 3 ++ .../syscalls/linux/vfs2/linux64_override_amd64.go | 40 ++++++++++------------ 5 files changed, 30 insertions(+), 32 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index 3f085d3ca..239a9f4b4 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -53,9 +53,7 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem { // inode implements kernfs.Inode. // -// TODO(gvisor.dev/issue/1476): Add device numbers to this inode (which are -// not included in InodeAttrs) to store the numbers of the appropriate -// socket device. Override InodeAttrs.Stat() accordingly. +// TODO(gvisor.dev/issue/1193): Device numbers. type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -69,11 +67,6 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr } // NewDentry constructs and returns a sockfs dentry. -// -// TODO(gvisor.dev/issue/1476): Currently, we are using -// sockfs.filesystem.NextIno() to get inode numbers. We should use -// device-specific numbers, so that we are not using the same generator for -// netstack, unix, etc. func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry { // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index b03ca2f26..a8278bffc 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -36,8 +36,11 @@ import ( type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - // TODO(gvisor.dev/issue/1484): VFS1 stores internal metadata for hostinet. - // We should perhaps rely on the host, much like in hostfs. + + // We store metadata for hostinet sockets internally. Technically, we should + // access metadata (e.g. through stat, chmod) on the host for correctness, + // but this is not very useful for inet socket fds, which do not belong to a + // concrete file anyway. vfs.DentryMetadataFileDescriptionImpl socketOpsCommon diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index eec71035d..f7d9b2ff4 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" @@ -41,6 +42,8 @@ type SocketVFS2 struct { socketOpsCommon } +var _ = socket.SocketVFS2(&SocketVFS2{}) + // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 5edc3cdf4..06d838868 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -43,6 +44,8 @@ type SocketVFS2 struct { socketOpsCommon } +var _ = socket.SocketVFS2(&SocketVFS2{}) + // NewSockfsFile creates a new socket file in the global sockfs mount and // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go index 074f58e5d..47c5d18e7 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go @@ -44,22 +44,21 @@ func Override(table map[uintptr]kernel.Syscall) { table[32] = syscalls.Supported("dup", Dup) table[33] = syscalls.Supported("dup2", Dup2) delete(table, 40) // sendfile - // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. - table[41] = syscalls.PartiallySupported("socket", Socket, "In process of porting socket syscalls to VFS2.", nil) - table[42] = syscalls.PartiallySupported("connect", Connect, "In process of porting socket syscalls to VFS2.", nil) - table[43] = syscalls.PartiallySupported("accept", Accept, "In process of porting socket syscalls to VFS2.", nil) - table[44] = syscalls.PartiallySupported("sendto", SendTo, "In process of porting socket syscalls to VFS2.", nil) - table[45] = syscalls.PartiallySupported("recvfrom", RecvFrom, "In process of porting socket syscalls to VFS2.", nil) - table[46] = syscalls.PartiallySupported("sendmsg", SendMsg, "In process of porting socket syscalls to VFS2.", nil) - table[47] = syscalls.PartiallySupported("recvmsg", RecvMsg, "In process of porting socket syscalls to VFS2.", nil) - table[48] = syscalls.PartiallySupported("shutdown", Shutdown, "In process of porting socket syscalls to VFS2.", nil) - table[49] = syscalls.PartiallySupported("bind", Bind, "In process of porting socket syscalls to VFS2.", nil) - table[50] = syscalls.PartiallySupported("listen", Listen, "In process of porting socket syscalls to VFS2.", nil) - table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil) - table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil) - table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil) - table[54] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil) - table[55] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil) + table[41] = syscalls.Supported("socket", Socket) + table[42] = syscalls.Supported("connect", Connect) + table[43] = syscalls.Supported("accept", Accept) + table[44] = syscalls.Supported("sendto", SendTo) + table[45] = syscalls.Supported("recvfrom", RecvFrom) + table[46] = syscalls.Supported("sendmsg", SendMsg) + table[47] = syscalls.Supported("recvmsg", RecvMsg) + table[48] = syscalls.Supported("shutdown", Shutdown) + table[49] = syscalls.Supported("bind", Bind) + table[50] = syscalls.Supported("listen", Listen) + table[51] = syscalls.Supported("getsockname", GetSockName) + table[52] = syscalls.Supported("getpeername", GetPeerName) + table[53] = syscalls.Supported("socketpair", SocketPair) + table[54] = syscalls.Supported("setsockopt", SetSockOpt) + table[55] = syscalls.Supported("getsockopt", GetSockOpt) table[59] = syscalls.Supported("execve", Execve) table[72] = syscalls.Supported("fcntl", Fcntl) delete(table, 73) // flock @@ -145,8 +144,7 @@ func Override(table map[uintptr]kernel.Syscall) { delete(table, 285) // fallocate table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime) table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime) - // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. - table[288] = syscalls.PartiallySupported("accept4", Accept4, "In process of porting socket syscalls to VFS2.", nil) + table[288] = syscalls.Supported("accept4", Accept4) delete(table, 289) // signalfd4 table[290] = syscalls.Supported("eventfd2", Eventfd2) table[291] = syscalls.Supported("epoll_create1", EpollCreate1) @@ -155,11 +153,9 @@ func Override(table map[uintptr]kernel.Syscall) { delete(table, 294) // inotify_init1 table[295] = syscalls.Supported("preadv", Preadv) table[296] = syscalls.Supported("pwritev", Pwritev) - // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. - table[299] = syscalls.PartiallySupported("recvmmsg", RecvMMsg, "In process of porting socket syscalls to VFS2.", nil) + table[299] = syscalls.Supported("recvmmsg", RecvMMsg) table[306] = syscalls.Supported("syncfs", Syncfs) - // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. - table[307] = syscalls.PartiallySupported("sendmmsg", SendMMsg, "In process of porting socket syscalls to VFS2.", nil) + table[307] = syscalls.Supported("sendmmsg", SendMMsg) table[316] = syscalls.Supported("renameat2", Renameat2) delete(table, 319) // memfd_create table[322] = syscalls.Supported("execveat", Execveat) -- cgit v1.2.3 From 9115f26851b6f00ae01e9c130e3b5b342495c9e5 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 7 May 2020 14:00:36 -0700 Subject: Allocate device numbers for VFS2 filesystems. Updates #1197, #1198, #1672 PiperOrigin-RevId: 310432006 --- pkg/abi/linux/dev.go | 4 + pkg/sentry/fsimpl/devpts/devpts.go | 40 +++++-- pkg/sentry/fsimpl/ext/ext.go | 16 ++- pkg/sentry/fsimpl/ext/filesystem.go | 8 +- pkg/sentry/fsimpl/ext/inode.go | 2 + pkg/sentry/fsimpl/gofer/gofer.go | 19 +++- pkg/sentry/fsimpl/host/host.go | 143 +++++++++++++------------ pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 4 +- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 36 +++++-- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 6 +- pkg/sentry/fsimpl/kernfs/symlink.go | 8 +- pkg/sentry/fsimpl/pipefs/BUILD | 1 + pkg/sentry/fsimpl/pipefs/pipefs.go | 79 +++++++++----- pkg/sentry/fsimpl/proc/filesystem.go | 29 +++-- pkg/sentry/fsimpl/proc/subtasks.go | 12 +-- pkg/sentry/fsimpl/proc/task.go | 64 +++++------ pkg/sentry/fsimpl/proc/task_fds.go | 28 ++--- pkg/sentry/fsimpl/proc/task_files.go | 16 +-- pkg/sentry/fsimpl/proc/task_net.go | 38 +++---- pkg/sentry/fsimpl/proc/tasks.go | 47 ++++---- pkg/sentry/fsimpl/proc/tasks_files.go | 8 +- pkg/sentry/fsimpl/proc/tasks_sys.go | 108 +++++++++---------- pkg/sentry/fsimpl/sockfs/BUILD | 1 + pkg/sentry/fsimpl/sockfs/sockfs.go | 46 ++++++-- pkg/sentry/fsimpl/sys/sys.go | 21 +++- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 39 ++++--- pkg/sentry/kernel/kernel.go | 10 +- pkg/sentry/socket/hostinet/BUILD | 1 - pkg/sentry/socket/hostinet/socket_vfs2.go | 4 +- pkg/sentry/socket/netlink/BUILD | 1 - pkg/sentry/socket/netlink/provider_vfs2.go | 4 +- pkg/sentry/socket/netstack/BUILD | 1 - pkg/sentry/socket/netstack/netstack_vfs2.go | 4 +- pkg/sentry/socket/unix/BUILD | 1 - pkg/sentry/socket/unix/unix_vfs2.go | 4 +- pkg/sentry/vfs/anonfs.go | 2 +- pkg/sentry/vfs/device.go | 2 +- runsc/boot/loader.go | 5 +- 38 files changed, 517 insertions(+), 345 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index 89f9a793f..fa3ae5f18 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -36,6 +36,10 @@ func DecodeDeviceID(rdev uint32) (uint16, uint32) { // // See Documentations/devices.txt and uapi/linux/major.h. const ( + // UNNAMED_MAJOR is the major device number for "unnamed" devices, whose + // minor numbers are dynamically allocated by the kernel. + UNNAMED_MAJOR = 0 + // MEM_MAJOR is the major device number for "memory" character devices. MEM_MAJOR = 1 diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 94db8fe5c..c03c65445 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -51,21 +51,37 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EINVAL } - fs, root := fstype.newFilesystem(vfsObj, creds) - return fs.VFSFilesystem(), root.VFSDentry(), nil + fs, root, err := fstype.newFilesystem(vfsObj, creds) + if err != nil { + return nil, nil, err + } + return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil +} + +type filesystem struct { + kernfs.Filesystem + + devMinor uint32 } // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*kernfs.Filesystem, *kernfs.Dentry) { - fs := &kernfs.Filesystem{} - fs.VFSFilesystem().Init(vfsObj, fstype, fs) +func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + fs := &filesystem{ + devMinor: devMinor, + } + fs.Filesystem.VFSFilesystem().Init(vfsObj, fstype, fs) // Construct the root directory. This is always inode id 1. root := &rootInode{ slaves: make(map[uint32]*slaveInode), } - root.InodeAttrs.Init(creds, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) root.dentry.Init(root) @@ -74,7 +90,7 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) master.dentry.Init(master) // Add the master as a child of the root. @@ -83,7 +99,13 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds }) root.IncLinks(links) - return fs, &root.dentry + return fs, &root.dentry, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() } // rootInode is the root directory inode for the devpts mounts. @@ -140,7 +162,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - slave.InodeAttrs.Init(creds, uint64(idx+3), linux.ModeCharacterDevice|0600) + slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) slave.dentry.Init(slave) i.slaves[idx] = slave diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index 7176af6d1..dac6effbf 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -105,36 +105,50 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // EACCESS should be returned according to mount(2). Filesystem independent // flags (like readonly) are currently not available in pkg/sentry/vfs. + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + dev, err := getDeviceFd(source, opts) if err != nil { return nil, nil, err } - fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} + fs := filesystem{ + dev: dev, + inodeCache: make(map[uint32]*inode), + devMinor: devMinor, + } fs.vfsfs.Init(vfsObj, &fsType, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { + fs.vfsfs.DecRef() return nil, nil, err } if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { // mount(2) specifies that EINVAL should be returned if the superblock is // invalid. + fs.vfsfs.DecRef() return nil, nil, syserror.EINVAL } // Refuse to mount if the filesystem is incompatible. if !isCompatible(fs.sb) { + fs.vfsfs.DecRef() return nil, nil, syserror.EINVAL } fs.bgs, err = readBlockGroups(dev, fs.sb) if err != nil { + fs.vfsfs.DecRef() return nil, nil, err } rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode) if err != nil { + fs.vfsfs.DecRef() return nil, nil, err } rootInode.incRef() diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 77b644275..557963e03 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -64,6 +64,10 @@ type filesystem struct { // bgs represents all the block group descriptors for the filesystem. // Immutable after initialization. bgs []disklayout.BlockGroup + + // devMinor is this filesystem's device minor number. Immutable after + // initialization. + devMinor uint32 } // Compiles only if filesystem implements vfs.FilesystemImpl. @@ -366,7 +370,9 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() {} +func (fs *filesystem) Release() { + fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) +} // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index a98512350..485f86f4b 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -204,6 +204,8 @@ func (in *inode) statTo(stat *linux.Statx) { stat.Atime = in.diskInode.AccessTime().StatxTimestamp() stat.Ctime = in.diskInode.ChangeTime().StatxTimestamp() stat.Mtime = in.diskInode.ModificationTime().StatxTimestamp() + stat.DevMajor = linux.UNNAMED_MAJOR + stat.DevMinor = in.fs.devMinor // TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks // (including metadata blocks) required to represent this file. } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 9ab8fdc65..e68e37ebc 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -81,6 +81,9 @@ type filesystem struct { // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock + // devMinor is the filesystem's minor device number. devMinor is immutable. + devMinor uint32 + // uid and gid are the effective KUID and KGID of the filesystem's creator, // and are used as the owner and group for files that don't specify one. // uid and gid are immutable. @@ -399,14 +402,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + attachFile.close(ctx) + client.Close() + return nil, nil, err + } fs := &filesystem{ mfp: mfp, opts: fsopts, iopts: iopts, - uid: creds.EffectiveKUID, - gid: creds.EffectiveKGID, client: client, clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + uid: creds.EffectiveKUID, + gid: creds.EffectiveKGID, syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), } @@ -464,6 +474,8 @@ func (fs *filesystem) Release() { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() } + + fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } // dentry implements vfs.DentryImpl. @@ -796,7 +808,8 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) - // TODO(gvisor.dev/issue/1198): device number + stat.DevMajor = linux.UNNAMED_MAJOR + stat.DevMinor = d.fs.devMinor } func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 144e04905..55de9c438 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -115,15 +115,28 @@ func (filesystemType) Name() string { // // Note that there should only ever be one instance of host.filesystem, // a global mount for host fds. -func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem { - fs := &filesystem{} +func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, err + } + fs := &filesystem{ + devMinor: devMinor, + } fs.VFSFilesystem().Init(vfsObj, filesystemType{}, fs) - return fs.VFSFilesystem() + return fs.VFSFilesystem(), nil } // filesystem implements vfs.FilesystemImpl. type filesystem struct { kernfs.Filesystem + + devMinor uint32 +} + +func (fs *filesystem) Release() { + fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() } func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { @@ -219,7 +232,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode. -func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL } @@ -227,73 +240,73 @@ func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, erro return linux.Statx{}, syserror.EINVAL } + fs := vfsfs.Impl().(*filesystem) + // Limit our host call only to known flags. mask := opts.Mask & linux.STATX_ALL var s unix.Statx_t err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s) - // Fallback to fstat(2), if statx(2) is not supported on the host. - // - // TODO(b/151263641): Remove fallback. if err == syserror.ENOSYS { - return i.fstat(opts) - } else if err != nil { + // Fallback to fstat(2), if statx(2) is not supported on the host. + // + // TODO(b/151263641): Remove fallback. + return i.fstat(fs) + } + if err != nil { return linux.Statx{}, err } - ls := linux.Statx{Mask: mask} - // Unconditionally fill blksize, attributes, and device numbers, as indicated - // by /include/uapi/linux/stat.h. - // - // RdevMajor/RdevMinor are left as zero, so as not to expose host device - // numbers. - // - // TODO(gvisor.dev/issue/1672): Use kernfs-specific, internally defined - // device numbers. If we use the device number from the host, it may collide - // with another sentry-internal device number. We handle device/inode - // numbers without relying on the host to prevent collisions. - ls.Blksize = s.Blksize - ls.Attributes = s.Attributes - ls.AttributesMask = s.Attributes_mask - - if mask&linux.STATX_TYPE != 0 { + // Unconditionally fill blksize, attributes, and device numbers, as + // indicated by /include/uapi/linux/stat.h. Inode number is always + // available, since we use our own rather than the host's. + ls := linux.Statx{ + Mask: linux.STATX_INO, + Blksize: s.Blksize, + Attributes: s.Attributes, + Ino: i.ino, + AttributesMask: s.Attributes_mask, + DevMajor: linux.UNNAMED_MAJOR, + DevMinor: fs.devMinor, + } + + // Copy other fields that were returned by the host. RdevMajor/RdevMinor + // are never copied (and therefore left as zero), so as not to expose host + // device numbers. + ls.Mask |= s.Mask & linux.STATX_ALL + if s.Mask&linux.STATX_TYPE != 0 { ls.Mode |= s.Mode & linux.S_IFMT } - if mask&linux.STATX_MODE != 0 { + if s.Mask&linux.STATX_MODE != 0 { ls.Mode |= s.Mode &^ linux.S_IFMT } - if mask&linux.STATX_NLINK != 0 { + if s.Mask&linux.STATX_NLINK != 0 { ls.Nlink = s.Nlink } - if mask&linux.STATX_UID != 0 { + if s.Mask&linux.STATX_UID != 0 { ls.UID = s.Uid } - if mask&linux.STATX_GID != 0 { + if s.Mask&linux.STATX_GID != 0 { ls.GID = s.Gid } - if mask&linux.STATX_ATIME != 0 { + if s.Mask&linux.STATX_ATIME != 0 { ls.Atime = unixToLinuxStatxTimestamp(s.Atime) } - if mask&linux.STATX_BTIME != 0 { + if s.Mask&linux.STATX_BTIME != 0 { ls.Btime = unixToLinuxStatxTimestamp(s.Btime) } - if mask&linux.STATX_CTIME != 0 { + if s.Mask&linux.STATX_CTIME != 0 { ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime) } - if mask&linux.STATX_MTIME != 0 { + if s.Mask&linux.STATX_MTIME != 0 { ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime) } - if mask&linux.STATX_SIZE != 0 { + if s.Mask&linux.STATX_SIZE != 0 { ls.Size = s.Size } - if mask&linux.STATX_BLOCKS != 0 { + if s.Mask&linux.STATX_BLOCKS != 0 { ls.Blocks = s.Blocks } - // Use our own internal inode number. - if mask&linux.STATX_INO != 0 { - ls.Ino = i.ino - } - return ls, nil } @@ -305,36 +318,30 @@ func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, erro // of a mask or sync flags. fstat(2) does not provide any metadata // equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so // those fields remain empty. -func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { var s unix.Stat_t if err := unix.Fstat(i.hostFD, &s); err != nil { return linux.Statx{}, err } - // Note that rdev numbers are left as 0; do not expose host device numbers. - ls := linux.Statx{ - Mask: linux.STATX_BASIC_STATS, - Blksize: uint32(s.Blksize), - Nlink: uint32(s.Nlink), - UID: s.Uid, - GID: s.Gid, - Mode: uint16(s.Mode), - Size: uint64(s.Size), - Blocks: uint64(s.Blocks), - Atime: timespecToStatxTimestamp(s.Atim), - Ctime: timespecToStatxTimestamp(s.Ctim), - Mtime: timespecToStatxTimestamp(s.Mtim), - } - - // Use our own internal inode number. - // - // TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well. - // If we use the device number from the host, it may collide with another - // sentry-internal device number. We handle device/inode numbers without - // relying on the host to prevent collisions. - ls.Ino = i.ino - - return ls, nil + // As with inode.Stat(), we always use internal device and inode numbers, + // and never expose the host's represented device numbers. + return linux.Statx{ + Mask: linux.STATX_BASIC_STATS, + Blksize: uint32(s.Blksize), + Nlink: uint32(s.Nlink), + UID: s.Uid, + GID: s.Gid, + Mode: uint16(s.Mode), + Ino: i.ino, + Size: uint64(s.Size), + Blocks: uint64(s.Blocks), + Atime: timespecToStatxTimestamp(s.Atim), + Ctime: timespecToStatxTimestamp(s.Ctim), + Mtime: timespecToStatxTimestamp(s.Mtim), + DevMajor: linux.UNNAMED_MAJOR, + DevMinor: fs.devMinor, + }, nil } // SetStat implements kernfs.Inode. @@ -453,8 +460,6 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount) (*vfs.F } // fileDescription is embedded by host fd implementations of FileDescriptionImpl. -// -// TODO(gvisor.dev/issue/1672): Implement Waitable interface. type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -471,12 +476,12 @@ type fileDescription struct { // SetStat implements vfs.FileDescriptionImpl. func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - return f.inode.SetStat(ctx, nil, creds, opts) + return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts) } // Stat implements vfs.FileDescriptionImpl. func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(nil, opts) + return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts) } // Release implements vfs.FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index c7779fc11..1568a9d49 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -44,11 +44,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 615592d5f..982daa2e6 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -192,15 +192,17 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // Must be initialized by Init prior to first use. type InodeAttrs struct { - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -209,6 +211,8 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMo if mode.FileType() == linux.ModeDirectory { nlink = 2 } + a.devMajor = devMajor + a.devMinor = devMinor atomic.StoreUint64(&a.ino, ino) atomic.StoreUint32(&a.mode, uint32(mode)) atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) @@ -216,6 +220,16 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMo atomic.StoreUint32(&a.nlink, nlink) } +// DevMajor returns the device major number. +func (a *InodeAttrs) DevMajor() uint32 { + return a.devMajor +} + +// DevMinor returns the device minor number. +func (a *InodeAttrs) DevMinor() uint32 { + return a.devMinor +} + // Ino returns the inode id. func (a *InodeAttrs) Ino() uint64 { return atomic.LoadUint64(&a.ino) @@ -232,6 +246,8 @@ func (a *InodeAttrs) Mode() linux.FileMode { func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.DevMajor = a.devMajor + stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) stat.Mode = uint16(a.Mode()) stat.UID = atomic.LoadUint32(&a.uid) @@ -544,9 +560,9 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry { +func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry { inode := &StaticDirectory{} - inode.Init(creds, ino, perm) + inode.Init(creds, devMajor, devMinor, ino, perm) dentry := &Dentry{} dentry.Init(inode) @@ -559,11 +575,11 @@ func NewStaticDir(creds *auth.Credentials, ino uint64, perm linux.FileMode, chil } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.FileMode) { +func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - s.InodeAttrs.Init(creds, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements kernfs.Inode. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 1c5d3e7e7..412cf6ac9 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -75,7 +75,7 @@ type file struct { func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) d := &kernfs.Dentry{} d.Init(f) @@ -107,7 +107,7 @@ type readonlyDir struct { func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { dir := &readonlyDir{} - dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) dir.dentry.Init(dir) @@ -137,7 +137,7 @@ type dir struct { func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) dir.dentry.Init(dir) diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 0aa6dc979..2ab3f53fd 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -35,9 +35,9 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, ino uint64, target string) *Dentry { +func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) *Dentry { inode := &StaticSymlink{} - inode.Init(creds, ino, target) + inode.Init(creds, devMajor, devMinor, ino, target) d := &Dentry{} d.Init(inode) @@ -45,9 +45,9 @@ func NewStaticSymlink(creds *auth.Credentials, ino uint64, target string) *Dentr } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string) { +func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode. diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD index 0d411606f..5950a2d59 100644 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fspath", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 5375e5e75..cab771211 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -17,8 +17,11 @@ package pipefs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" @@ -40,20 +43,36 @@ func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile panic("pipefs.filesystemType.GetFilesystem should never be called") } -// TODO(gvisor.dev/issue/1193): -// -// - kernfs does not provide a way to implement statfs, from which we -// should indicate PIPEFS_MAGIC. -// -// - kernfs does not provide a way to override names for -// vfs.FilesystemImpl.PrependPath(); pipefs inodes should use synthetic -// name fmt.Sprintf("pipe:[%d]", inode.ino). +type filesystem struct { + kernfs.Filesystem + + devMinor uint32 +} // NewFilesystem sets up and returns a new vfs.Filesystem implemented by pipefs. -func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem { - fs := &kernfs.Filesystem{} - fs.VFSFilesystem().Init(vfsObj, filesystemType{}, fs) - return fs.VFSFilesystem() +func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, err + } + fs := &filesystem{ + devMinor: devMinor, + } + fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs) + return fs.Filesystem.VFSFilesystem(), nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode) + b.PrependComponent(fmt.Sprintf("pipe:[%d]", inode.ino)) + return vfs.PrependPathSyntheticError{} } // inode implements kernfs.Inode. @@ -71,11 +90,11 @@ type inode struct { ctime ktime.Time } -func newInode(ctx context.Context, fs *kernfs.Filesystem) *inode { +func newInode(ctx context.Context, fs *filesystem) *inode { creds := auth.CredentialsFromContext(ctx) return &inode{ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), - ino: fs.NextIno(), + ino: fs.Filesystem.NextIno(), uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, ctime: ktime.NowFromContext(ctx), @@ -98,19 +117,20 @@ func (i *inode) Mode() linux.FileMode { func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ - Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, - Blksize: usermem.PageSize, - Nlink: 1, - UID: uint32(i.uid), - GID: uint32(i.gid), - Mode: pipeMode, - Ino: i.ino, - Size: 0, - Blocks: 0, - Atime: ts, - Ctime: ts, - Mtime: ts, - // TODO(gvisor.dev/issue/1197): Device number. + Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, + Blksize: usermem.PageSize, + Nlink: 1, + UID: uint32(i.uid), + GID: uint32(i.gid), + Mode: pipeMode, + Ino: i.ino, + Size: 0, + Blocks: 0, + Atime: ts, + Ctime: ts, + Mtime: ts, + DevMajor: linux.UNNAMED_MAJOR, + DevMinor: vfsfs.Impl().(*filesystem).devMinor, }, nil } @@ -122,6 +142,9 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth. return syserror.EPERM } +// TODO(gvisor.dev/issue/1193): kernfs does not provide a way to implement +// statfs, from which we should indicate PIPEFS_MAGIC. + // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags) @@ -132,7 +155,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) { - fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + fs := mnt.Filesystem().Impl().(*filesystem) inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(inode) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 104fc9030..609210253 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -41,6 +41,12 @@ func (FilesystemType) Name() string { return Name } +type filesystem struct { + kernfs.Filesystem + + devMinor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { k := kernel.KernelFromContext(ctx) @@ -51,8 +57,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if pidns == nil { return nil, nil, fmt.Errorf("procfs requires a PID namespace") } - - procfs := &kernfs.Filesystem{} + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + procfs := &filesystem{ + devMinor: devMinor, + } procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -61,21 +72,27 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - _, dentry := newTasksInode(procfs, k, pidns, cgroups) + _, dentry := procfs.newTasksInode(k, pidns, cgroups) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil } +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + // dynamicInode is an overfitted interface for common Inodes with // dynamicByteSource types used in procfs. type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { - inode.Init(creds, ino, inode, perm) +func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { + inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) d := &kernfs.Dentry{} d.Init(inode) diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index a5cfa8333..36a911db4 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -37,23 +37,23 @@ type subtasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid + fs *filesystem task *kernel.Task pidns *kernel.PIDNamespace - inoGen InoGenerator cgroupControllers map[string]string } var _ kernfs.Inode = (*subtasksInode)(nil) -func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenerator, cgroupControllers map[string]string) *kernfs.Dentry { +func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *kernfs.Dentry { subInode := &subtasksInode{ + fs: fs, task: task, pidns: pidns, - inoGen: inoGen, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) inode := &taskOwnedInode{Inode: subInode, owner: task} @@ -78,7 +78,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, e return nil, syserror.ENOENT } - subTaskDentry := newTaskInode(i.inoGen, subTask, i.pidns, false, i.cgroupControllers) + subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers) return subTaskDentry.VFSDentry(), nil } @@ -102,7 +102,7 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(tid), 10), Type: linux.DT_DIR, - Ino: i.inoGen.NextIno(), + Ino: i.fs.NextIno(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 66419d91b..482055db1 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -43,45 +43,45 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry { +func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry { // TODO(gvisor.dev/issue/164): Fail with ESRCH if task exited. contents := map[string]*kernfs.Dentry{ - "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}), - "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), - "comm": newComm(task, inoGen.NextIno(), 0444), - "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), - "exe": newExeSymlink(task, inoGen.NextIno()), - "fd": newFDDirInode(task, inoGen), - "fdinfo": newFDInfoDirInode(task, inoGen), - "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}), - "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)), - "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}), - "mountinfo": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountInfoData{task: task}), - "mounts": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountsData{task: task}), - "net": newTaskNetDir(task, inoGen), - "ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{ - "net": newNamespaceSymlink(task, inoGen.NextIno(), "net"), - "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"), - "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"), + "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}), + "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), + "comm": fs.newComm(task, fs.NextIno(), 0444), + "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), + "exe": fs.newExeSymlink(task, fs.NextIno()), + "fd": fs.newFDDirInode(task), + "fdinfo": fs.newFDInfoDirInode(task), + "gid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), + "io": fs.newTaskOwnedFile(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), + "maps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mapsData{task: task}), + "mountinfo": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountInfoData{task: task}), + "mounts": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountsData{task: task}), + "net": fs.newTaskNetDir(task), + "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]*kernfs.Dentry{ + "net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"), + "pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"), + "user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"), }), - "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")), - "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}), - "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), - "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), - "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), + "oom_score": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newStaticFile("0\n")), + "oom_score_adj": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}), + "smaps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &smapsData{task: task}), + "stat": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statmData{task: task}), + "status": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "uid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers) + contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers) } if len(cgroupControllers) > 0 { - contents["cgroup"] = newTaskOwnedFile(task, inoGen.NextIno(), 0444, newCgroupData(cgroupControllers)) + contents["cgroup"] = fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) } taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode := &taskOwnedInode{Inode: taskInode, owner: task} dentry := &kernfs.Dentry{} @@ -126,9 +126,9 @@ type taskOwnedInode struct { var _ kernfs.Inode = (*taskOwnedInode)(nil) -func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { +func (fs *filesystem) newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), ino, inode, perm) + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) taskInode := &taskOwnedInode{Inode: inode, owner: task} d := &kernfs.Dentry{} @@ -136,11 +136,11 @@ func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode return d } -func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry { +func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry { dir := &kernfs.StaticDirectory{} // Note: credentials are overridden by taskOwnedInode. - dir.Init(task.Credentials(), ino, perm) + dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) inode := &taskOwnedInode{Inode: dir, owner: task} d := &kernfs.Dentry{} diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 8ad976073..44ccc9e4a 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -53,8 +53,8 @@ func taskFDExists(t *kernel.Task, fd int32) bool { } type fdDir struct { - inoGen InoGenerator - task *kernel.Task + fs *filesystem + task *kernel.Task // When produceSymlinks is set, dirents produces for the FDs are reported // as symlink. Otherwise, they are reported as regular files. @@ -85,7 +85,7 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, abs dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(fd), 10), Type: typ, - Ino: i.inoGen.NextIno(), + Ino: i.fs.NextIno(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -110,15 +110,15 @@ type fdDirInode struct { var _ kernfs.Inode = (*fdDirInode)(nil) -func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { +func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { inode := &fdDirInode{ fdDir: fdDir{ - inoGen: inoGen, + fs: fs, task: task, produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -137,7 +137,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro if !taskFDExists(i.task, fd) { return nil, syserror.ENOENT } - taskDentry := newFDSymlink(i.task, fd, i.inoGen.NextIno()) + taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()) return taskDentry.VFSDentry(), nil } @@ -186,12 +186,12 @@ type fdSymlink struct { var _ kernfs.Inode = (*fdSymlink)(nil) -func newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry { +func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry { inode := &fdSymlink{ task: task, fd: fd, } - inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777) + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -234,14 +234,14 @@ type fdInfoDirInode struct { var _ kernfs.Inode = (*fdInfoDirInode)(nil) -func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { +func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { inode := &fdInfoDirInode{ fdDir: fdDir{ - inoGen: inoGen, - task: task, + fs: fs, + task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -264,7 +264,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, task: i.task, fd: fd, } - dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data) + dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data) return dentry.VFSDentry(), nil } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 515f25327..2f297e48a 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -241,9 +241,9 @@ type commInode struct { task *kernel.Task } -func newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) d := &kernfs.Dentry{} d.Init(inode) @@ -596,9 +596,9 @@ type exeSymlink struct { var _ kernfs.Inode = (*exeSymlink)(nil) -func newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { +func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777) + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -729,7 +729,7 @@ type namespaceSymlink struct { task *kernel.Task } -func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { +func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { // Namespace symlinks should contain the namespace name and the inode number // for the namespace instance, so for example user:[123456]. We currently fake // the inode number by sticking the symlink inode in its place. @@ -737,7 +737,7 @@ func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentr inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), ino, target) + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} d := &kernfs.Dentry{} @@ -780,11 +780,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 9c329341a..6bde27376 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -37,7 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { +func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry { k := task.Kernel() pidns := task.PIDNamespace() root := auth.NewRootCredentials(pidns.UserNamespace()) @@ -57,37 +57,37 @@ func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]*kernfs.Dentry{ - "dev": newDentry(root, inoGen.NextIno(), 0444, &netDevData{stack: stack}), - "snmp": newDentry(root, inoGen.NextIno(), 0444, &netSnmpData{stack: stack}), + "dev": fs.newDentry(root, fs.NextIno(), 0444, &netDevData{stack: stack}), + "snmp": fs.newDentry(root, fs.NextIno(), 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(arp)), - "netlink": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(netlink)), - "netstat": newDentry(root, inoGen.NextIno(), 0444, &netStatData{}), - "packet": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(packet)), - "protocols": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(protocols)), + "arp": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(arp)), + "netlink": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(netlink)), + "netstat": fs.newDentry(root, fs.NextIno(), 0444, &netStatData{}), + "packet": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(packet)), + "protocols": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(psched)), - "ptype": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(ptype)), - "route": newDentry(root, inoGen.NextIno(), 0444, &netRouteData{stack: stack}), - "tcp": newDentry(root, inoGen.NextIno(), 0444, &netTCPData{kernel: k}), - "udp": newDentry(root, inoGen.NextIno(), 0444, &netUDPData{kernel: k}), - "unix": newDentry(root, inoGen.NextIno(), 0444, &netUnixData{kernel: k}), + "psched": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(psched)), + "ptype": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(ptype)), + "route": fs.newDentry(root, fs.NextIno(), 0444, &netRouteData{stack: stack}), + "tcp": fs.newDentry(root, fs.NextIno(), 0444, &netTCPData{kernel: k}), + "udp": fs.newDentry(root, fs.NextIno(), 0444, &netUDPData{kernel: k}), + "unix": fs.newDentry(root, fs.NextIno(), 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = newDentry(root, inoGen.NextIno(), 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")) - contents["tcp6"] = newDentry(root, inoGen.NextIno(), 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = newDentry(root, inoGen.NextIno(), 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newDentry(root, fs.NextIno(), 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")) + contents["tcp6"] = fs.newDentry(root, fs.NextIno(), 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(upd6)) } } - return newTaskOwnedDir(task, inoGen.NextIno(), 0555, contents) + return fs.newTaskOwnedDir(task, fs.NextIno(), 0555, contents) } // ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6. diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 5aeda8c9b..b51d43954 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -33,11 +33,6 @@ const ( threadSelfName = "thread-self" ) -// InoGenerator generates unique inode numbers for a given filesystem. -type InoGenerator interface { - NextIno() uint64 -} - // tasksInode represents the inode for /proc/ directory. // // +stateify savable @@ -48,8 +43,8 @@ type tasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid - inoGen InoGenerator - pidns *kernel.PIDNamespace + fs *filesystem + pidns *kernel.PIDNamespace // '/proc/self' and '/proc/thread-self' have custom directory offsets in // Linux. So handle them outside of OrderedChildren. @@ -64,29 +59,29 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) { +func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]*kernfs.Dentry{ - "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}), - "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}), - "sys": newSysDir(root, inoGen, k), - "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/net"), - "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}), - "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), - "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}), + "cpuinfo": fs.newDentry(root, fs.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newDentry(root, fs.NextIno(), 0444, &filesystemsData{}), + "loadavg": fs.newDentry(root, fs.NextIno(), 0444, &loadavgData{}), + "sys": fs.newSysDir(root, k), + "meminfo": fs.newDentry(root, fs.NextIno(), 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newDentry(root, fs.NextIno(), 0444, &statData{}), + "uptime": fs.newDentry(root, fs.NextIno(), 0444, &uptimeData{}), + "version": fs.newDentry(root, fs.NextIno(), 0444, &versionData{}), } inode := &tasksInode{ pidns: pidns, - inoGen: inoGen, - selfSymlink: newSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(), - threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(), + fs: fs, + selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), + threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -118,7 +113,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return nil, syserror.ENOENT } - taskDentry := newTaskInode(i.inoGen, task, i.pidns, true, i.cgroupControllers) + taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers) return taskDentry.VFSDentry(), nil } @@ -144,7 +139,7 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback dirent := vfs.Dirent{ Name: selfName, Type: linux.DT_LNK, - Ino: i.inoGen.NextIno(), + Ino: i.fs.NextIno(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -156,7 +151,7 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback dirent := vfs.Dirent{ Name: threadSelfName, Type: linux.DT_LNK, - Ino: i.inoGen.NextIno(), + Ino: i.fs.NextIno(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -189,7 +184,7 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(tid), 10), Type: linux.DT_DIR, - Ino: i.inoGen.NextIno(), + Ino: i.fs.NextIno(), NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1, } if err := cb.Handle(dirent); err != nil { diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index e5f13b69e..7d8983aa5 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -41,9 +41,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { +func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { inode := &selfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|0777) + inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -83,9 +83,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { +func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { inode := &threadSelfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|0777) + inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 0e90e02fe..6dac2afa4 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -30,89 +30,89 @@ import ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. -func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { - return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "kernel": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "hostname": newDentry(root, inoGen.NextIno(), 0444, &hostnameData{}), - "shmall": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMALL)), - "shmmax": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMAX)), - "shmmni": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry { + return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "kernel": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}), + "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)), + "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)), }), - "vm": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0\n")), + "vm": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")), }), - "net": newSysNetDir(root, inoGen, k), + "net": fs.newSysNetDir(root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { +func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ - "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}), + "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "ipfrag_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("30")), - "ip_nonlocal_bind": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "ip_no_pmtu_disc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), + "ip_local_port_range": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")), + "ipfrag_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "tcp_available_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")), - "tcp_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1280")), - "tcp_dsack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_early_retrans": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fastopen": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_fastopen_key": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")), - "tcp_invalid_ratelimit": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_intvl": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_probes": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_keepalive_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("7200")), - "tcp_mtu_probing": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_no_metrics_save": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_probe_interval": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_probe_threshold": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "tcp_retries1": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")), - "tcp_retries2": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("15")), - "tcp_rfc1337": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), - "tcp_synack_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")), - "tcp_syn_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")), - "tcp_timestamps": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")), + "tcp_base_mss": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1280")), + "tcp_dsack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_fack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_fastopen": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_keepalive_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "tcp_retries1": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")), + "tcp_retries2": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")), + "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), }), - "core": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ - "default_qdisc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("pfifo_fast")), - "message_burst": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("10")), - "message_cost": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")), - "optmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")), - "rmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "rmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "somaxconn": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("128")), - "wmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), - "wmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")), + "core": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")), + "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")), + "optmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")), + "rmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")), + "rmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")), + "somaxconn": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("128")), + "wmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")), + "wmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")), }), } } - return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, contents) + return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD index 52084ddb5..9453277b8 100644 --- a/pkg/sentry/fsimpl/sockfs/BUILD +++ b/pkg/sentry/fsimpl/sockfs/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fspath", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index 239a9f4b4..ee0828a15 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -16,8 +16,11 @@ package sockfs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -41,19 +44,42 @@ func (filesystemType) Name() string { return "sockfs" } +type filesystem struct { + kernfs.Filesystem + + devMinor uint32 +} + // NewFilesystem sets up and returns a new sockfs filesystem. // // Note that there should only ever be one instance of sockfs.Filesystem, // backing a global socket mount. -func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem { - fs := &kernfs.Filesystem{} - fs.VFSFilesystem().Init(vfsObj, filesystemType{}, fs) - return fs.VFSFilesystem() +func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, err + } + fs := &filesystem{ + devMinor: devMinor, + } + fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs) + return fs.Filesystem.VFSFilesystem(), nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode) + b.PrependComponent(fmt.Sprintf("socket:[%d]", inode.InodeAttrs.Ino())) + return vfs.PrependPathSyntheticError{} } // inode implements kernfs.Inode. -// -// TODO(gvisor.dev/issue/1193): Device numbers. type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -67,11 +93,15 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr } // NewDentry constructs and returns a sockfs dentry. -func NewDentry(creds *auth.Credentials, ino uint64) *vfs.Dentry { +// +// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). +func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { + fs := mnt.Filesystem().Impl().(*filesystem) + // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.Init(creds, ino, filemode) + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(i) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 00f7d6214..0af373604 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -37,6 +37,8 @@ type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. type filesystem struct { kernfs.Filesystem + + devMinor uint32 } // Name implements vfs.FilesystemType.Name. @@ -46,7 +48,14 @@ func (FilesystemType) Name() string { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - fs := &filesystem{} + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + fs := &filesystem{ + devMinor: devMinor, + } fs.VFSFilesystem().Init(vfsObj, &fsType, fs) k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() @@ -77,6 +86,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return fs.VFSFilesystem(), root.VFSDentry(), nil } +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + // dir implements kernfs.Inode. type dir struct { kernfs.InodeAttrs @@ -90,7 +105,7 @@ type dir struct { func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { d := &dir{} - d.InodeAttrs.Init(creds, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) d.dentry.Init(d) @@ -127,7 +142,7 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) d := &kernfs.Dentry{} d.Init(c) return d diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index efc931468..405928bd0 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -63,6 +63,9 @@ type filesystem struct { // clock is a realtime clock used to set timestamps in file operations. clock time.Clock + // devMinor is the filesystem's minor device number. devMinor is immutable. + devMinor uint32 + // mu serializes changes to the Dentry tree. mu sync.RWMutex @@ -96,11 +99,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if memFileProvider == nil { panic("MemoryFileProviderFromContext returned nil") } - clock := time.RealtimeClockFromContext(ctx) - fs := filesystem{ - memFile: memFileProvider.MemoryFile(), - clock: clock, - } rootFileType := uint16(linux.S_IFDIR) newFSType := vfs.FilesystemType(&fstype) @@ -114,6 +112,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + clock := time.RealtimeClockFromContext(ctx) + fs := filesystem{ + memFile: memFileProvider.MemoryFile(), + clock: clock, + devMinor: devMinor, + } fs.vfsfs.Init(vfsObj, newFSType, &fs) var root *dentry @@ -125,6 +133,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt case linux.S_IFDIR: root = &fs.newDirectory(creds, 01777).dentry default: + fs.vfsfs.DecRef() return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType) } return &fs.vfsfs, &root.vfsd, nil @@ -132,6 +141,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release() { + fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } // dentry implements vfs.DentryImpl. @@ -188,8 +198,8 @@ func (d *dentry) DecRef() { // inode represents a filesystem object. type inode struct { - // clock is a realtime clock used to set timestamps in file operations. - clock time.Clock + // fs is the owning filesystem. fs is immutable. + fs *filesystem // refs is a reference count. refs is accessed using atomic memory // operations. @@ -232,7 +242,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, if mode.FileType() == 0 { panic("file type is required in FileMode") } - i.clock = fs.clock + i.fs = fs i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) @@ -327,7 +337,8 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Atime = linux.NsecToStatxTimestamp(i.atime) stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) - // TODO(gvisor.dev/issue/1197): Device number. + stat.DevMajor = linux.UNNAMED_MAJOR + stat.DevMinor = i.fs.devMinor switch impl := i.impl.(type) { case *regularFile: stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS @@ -401,7 +412,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EINVAL } } - now := i.clock.Now().Nanoseconds() + now := i.fs.clock.Now().Nanoseconds() if mask&linux.STATX_ATIME != 0 { if stat.Atime.Nsec == linux.UTIME_NOW { atomic.StoreInt64(&i.atime, now) @@ -518,7 +529,7 @@ func (i *inode) touchAtime(mnt *vfs.Mount) { if err := mnt.CheckBeginWrite(); err != nil { return } - now := i.clock.Now().Nanoseconds() + now := i.fs.clock.Now().Nanoseconds() i.mu.Lock() atomic.StoreInt64(&i.atime, now) i.mu.Unlock() @@ -527,7 +538,7 @@ func (i *inode) touchAtime(mnt *vfs.Mount) { // Preconditions: The caller has called vfs.Mount.CheckBeginWrite(). func (i *inode) touchCtime() { - now := i.clock.Now().Nanoseconds() + now := i.fs.clock.Now().Nanoseconds() i.mu.Lock() atomic.StoreInt64(&i.ctime, now) i.mu.Unlock() @@ -535,7 +546,7 @@ func (i *inode) touchCtime() { // Preconditions: The caller has called vfs.Mount.CheckBeginWrite(). func (i *inode) touchCMtime() { - now := i.clock.Now().Nanoseconds() + now := i.fs.clock.Now().Nanoseconds() i.mu.Lock() atomic.StoreInt64(&i.mtime, now) atomic.StoreInt64(&i.ctime, now) @@ -545,7 +556,7 @@ func (i *inode) touchCMtime() { // Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds // inode.mu. func (i *inode) touchCMtimeLocked() { - now := i.clock.Now().Nanoseconds() + now := i.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&i.mtime, now) atomic.StoreInt64(&i.ctime, now) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 271ea5faf..3617da8c6 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -373,7 +373,10 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("failed to initialize VFS: %v", err) } - pipeFilesystem := pipefs.NewFilesystem(&k.vfs) + pipeFilesystem, err := pipefs.NewFilesystem(&k.vfs) + if err != nil { + return fmt.Errorf("failed to create pipefs filesystem: %v", err) + } defer pipeFilesystem.DecRef() pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{}) if err != nil { @@ -381,7 +384,10 @@ func (k *Kernel) Init(args InitKernelArgs) error { } k.pipeMount = pipeMount - socketFilesystem := sockfs.NewFilesystem(&k.vfs) + socketFilesystem, err := sockfs.NewFilesystem(&k.vfs) + if err != nil { + return fmt.Errorf("failed to create sockfs filesystem: %v", err) + } defer socketFilesystem.DecRef() socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{}) if err != nil { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index deedd35f7..e82d6cd1e 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -26,7 +26,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index a8278bffc..677743113 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -50,8 +49,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) - d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + d := sockfs.NewDentry(t.Credentials(), mnt) s := &socketVFS2{ socketOpsCommon: socketOpsCommon{ diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 09ca00a4a..7212d8644 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,7 +20,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index dcd92b5cd..bb205be0d 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -16,7 +16,6 @@ package netlink import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -52,8 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) - d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + d := sockfs.NewDentry(t.Credentials(), mnt) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ccf9fcf5c..6129fb83d 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -27,7 +27,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index f7d9b2ff4..191970d41 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -53,8 +52,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu } mnt := t.Kernel().SocketMount() - fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) - d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + d := sockfs.NewDentry(t.Credentials(), mnt) s := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 941a91097..de2cc4bdf 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -21,7 +21,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 06d838868..45e109361 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -50,8 +49,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) - d := sockfs.NewDentry(t.Credentials(), fs.NextIno()) + d := sockfs.NewDentry(t.Credentials(), mnt) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d) if err != nil { diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index adebaeefb..caf770fd5 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -202,7 +202,7 @@ func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts St Ino: 1, Size: 0, Blocks: 0, - DevMajor: 0, + DevMajor: linux.UNNAMED_MAJOR, DevMinor: fs.devMinor, }, nil } diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go index bda5576fa..1e9dffc8f 100644 --- a/pkg/sentry/vfs/device.go +++ b/pkg/sentry/vfs/device.go @@ -103,7 +103,7 @@ func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mo } // GetAnonBlockDevMinor allocates and returns an unused minor device number for -// an "anonymous" block device with major number 0. +// an "anonymous" block device with major number UNNAMED_MAJOR. func (vfs *VirtualFilesystem) GetAnonBlockDevMinor() (uint32, error) { vfs.anonBlockDevMinorMu.Lock() defer vfs.anonBlockDevMinorMu.Unlock() diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 8c8bad11c..f802bc9fb 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -334,7 +334,10 @@ func New(args Args) (*Loader, error) { if kernel.VFS2Enabled { // Set up host mount that will be used for imported fds. - hostFilesystem := hostvfs2.NewFilesystem(k.VFS()) + hostFilesystem, err := hostvfs2.NewFilesystem(k.VFS()) + if err != nil { + return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err) + } defer hostFilesystem.DecRef() hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{}) if err != nil { -- cgit v1.2.3 From 8dd1d5b75a95100e747b1a88e9e557d5d2c30b64 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Tue, 12 May 2020 12:24:23 -0700 Subject: Don't call kernel.Task.Block() from netstack.SocketOperations.Write(). kernel.Task.Block() requires that the caller is running on the task goroutine. netstack.SocketOperations.Write() uses kernel.TaskFromContext() to call kernel.Task.Block() even if it's not running on the task goroutine. Stop doing that. PiperOrigin-RevId: 311178335 --- pkg/amutex/BUILD | 1 + pkg/amutex/amutex.go | 17 +++++++++++++++++ pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 13 +++++-------- pkg/sentry/socket/netstack/netstack_vfs2.go | 17 +++-------------- 5 files changed, 27 insertions(+), 22 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 9612f072e..ffc918846 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -6,6 +6,7 @@ go_library( name = "amutex", srcs = ["amutex.go"], visibility = ["//:sandbox"], + deps = ["//pkg/syserror"], ) go_test( diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go index 1c4fd1784..a078a31db 100644 --- a/pkg/amutex/amutex.go +++ b/pkg/amutex/amutex.go @@ -18,6 +18,8 @@ package amutex import ( "sync/atomic" + + "gvisor.dev/gvisor/pkg/syserror" ) // Sleeper must be implemented by users of the abortable mutex to allow for @@ -53,6 +55,21 @@ func (NoopSleeper) SleepFinish(success bool) {} // Interrupted implements Sleeper.Interrupted. func (NoopSleeper) Interrupted() bool { return false } +// Block blocks until either receiving from ch succeeds (in which case it +// returns nil) or sleeper is interrupted (in which case it returns +// syserror.ErrInterrupted). +func Block(sleeper Sleeper, ch <-chan struct{}) error { + cancel := sleeper.SleepStart() + select { + case <-ch: + sleeper.SleepFinish(true) + return nil + case <-cancel: + sleeper.SleepFinish(false) + return syserror.ErrInterrupted + } +} + // AbortableMutex is an abortable mutex. It allows Lock() to be aborted while it // waits to acquire the mutex. type AbortableMutex struct { diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 6129fb83d..333e0042e 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -18,6 +18,7 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/amutex", "//pkg/binary", "//pkg/context", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 81053d8ef..9d032f052 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -34,6 +34,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -553,11 +554,9 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := kernel.TaskFromContext(ctx) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) } @@ -626,11 +625,9 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := kernel.TaskFromContext(ctx) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ Atomic: true, // See above. }) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 191970d41..fcd8013c0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -16,6 +16,7 @@ package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -89,11 +90,6 @@ func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { s.socketOpsCommon.EventUnregister(e) } -// PRead implements vfs.FileDescriptionImpl. -func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { - return 0, syserror.ESPIPE -} - // Read implements vfs.FileDescriptionImpl. func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // All flags other than RWF_NOWAIT should be ignored. @@ -115,11 +111,6 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs. return int64(n), nil } -// PWrite implements vfs.FileDescriptionImpl. -func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - return 0, syserror.ESPIPE -} - // Write implements vfs.FileDescriptionImpl. func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // All flags other than RWF_NOWAIT should be ignored. @@ -135,11 +126,9 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs } if resCh != nil { - t := kernel.TaskFromContext(ctx) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) } -- cgit v1.2.3 From 8b8774d7152eb61fc6273bbae679e80c34188517 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 13 May 2020 19:47:42 -0700 Subject: Stub support for TCP_SYNCNT and TCP_WINDOW_CLAMP. This change adds support for TCP_SYNCNT and TCP_WINDOW_CLAMP options in GetSockOpt/SetSockOpt. This change does not really change any behaviour in Netstack and only stores/returns the stored value. Actual honoring of these options will be added as required. Fixes #2626, #2625 PiperOrigin-RevId: 311453777 --- pkg/sentry/socket/netstack/netstack.go | 39 ++++++ pkg/tcpip/tcpip.go | 17 +++ pkg/tcpip/transport/tcp/endpoint.go | 64 ++++++++- pkg/tcpip/transport/tcp/protocol.go | 21 +++ test/syscalls/linux/socket_ip_tcp_generic.cc | 45 +++++++ test/syscalls/linux/tcp_socket.cc | 185 ++++++++++++++++++++++++++- 6 files changed, 367 insertions(+), 4 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9d032f052..9dea2b5ff 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1321,6 +1321,29 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_SYNCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + + case linux.TCP_WINDOW_CLAMP: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil default: emitUnimplementedEventTCP(t, name) } @@ -1790,6 +1813,22 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + case linux.TCP_SYNCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v))) + + case linux.TCP_WINDOW_CLAMP: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 1ca4088c9..6270146fa 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -622,6 +622,19 @@ const ( // // A zero value indicates the default. TTLOption + + // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of + // SYN retransmits that TCP should send before aborting the attempt to + // connect. It cannot exceed 255. + // + // NOTE: This option is currently only stubbed out and is no-op. + TCPSynCountOption + + // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size + // of the advertised window to this value. + // + // NOTE: This option is currently only stubed out and is a no-op + TCPWindowClampOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by @@ -690,6 +703,10 @@ type TCPMinRTOOption time.Duration // switches to using SYN cookies. type TCPSynRcvdCountThresholdOption uint64 +// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide +// default for number of times SYN is retransmitted before aborting a connect. +type TCPSynRetriesOption uint8 + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 07d3e64c8..71735029e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -470,6 +470,17 @@ type endpoint struct { // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 + // maxSynRetries is the maximum number of SYN retransmits that TCP should + // send before aborting the attempt to connect. It cannot exceed 255. + // + // NOTE: This is currently a no-op and does not change the SYN + // retransmissions. + maxSynRetries uint8 + + // windowClamp is used to bound the size of the advertised window to + // this value. + windowClamp uint32 + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -795,8 +806,10 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, - uniqueID: s.UniqueID(), - txHash: s.Rand().Uint32(), + uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), + windowClamp: DefaultReceiveBufferSize, + maxSynRetries: DefaultSynRetries, } var ss SendBufferSizeOption @@ -829,6 +842,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.tcpLingerTimeout = time.Duration(tcpLT) } + var synRetries tcpip.TCPSynRetriesOption + if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil { + e.maxSynRetries = uint8(synRetries) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -1603,6 +1621,36 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.ttl = uint8(v) e.UnlockUser() + case tcpip.TCPSynCountOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.maxSynRetries = uint8(v) + e.UnlockUser() + + case tcpip.TCPWindowClampOption: + if v == 0 { + e.LockUser() + switch e.EndpointState() { + case StateClose, StateInitial: + e.windowClamp = 0 + e.UnlockUser() + return nil + default: + e.UnlockUser() + return tcpip.ErrInvalidOptionValue + } + } + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if v < rs.Min/2 { + v = rs.Min / 2 + } + } + e.LockUser() + e.windowClamp = uint32(v) + e.UnlockUser() } return nil } @@ -1826,6 +1874,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.UnlockUser() return v, nil + case tcpip.TCPSynCountOption: + e.LockUser() + v := int(e.maxSynRetries) + e.UnlockUser() + return v, nil + + case tcpip.TCPWindowClampOption: + e.LockUser() + v := int(e.windowClamp) + e.UnlockUser() + return v, nil + default: return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index cfd9a4e8e..1e64283b5 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -64,6 +64,10 @@ const ( // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger // in TIME_WAIT state before being marked closed. DefaultTCPTimeWaitTimeout = 60 * time.Second + + // DefaultSynRetries is the default value for the number of SYN retransmits + // before a connect is aborted. + DefaultSynRetries = 6 ) // SACKEnabled option can be used to enable SACK support in the TCP @@ -164,6 +168,7 @@ type protocol struct { tcpTimeWaitTimeout time.Duration minRTO time.Duration synRcvdCount synRcvdCounter + synRetries uint8 dispatcher *dispatcher } @@ -346,6 +351,15 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPSynRetriesOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.synRetries = uint8(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -420,6 +434,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPSynRetriesOption: + p.mu.RLock() + *v = tcpip.TCPSynRetriesOption(p.synRetries) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -452,6 +472,7 @@ func NewProtocol() stack.TransportProtocol { tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), + synRetries: DefaultSynRetries, minRTO: MinRTO, } } diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 27779e47c..fa81845fd 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -876,6 +876,51 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { EXPECT_EQ(get, kAbove); } +TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &kZero, + sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + // Setting TCP_WINDOW_CLAMP to zero for a connected socket is not permitted. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &kZero, sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); + + // Non-zero clamp values below MIN_SO_RCVBUF/2 should result in the clamp + // being set to MIN_SO_RCVBUF/2. + int below_half_min_so_rcvbuf = min_so_rcvbuf / 2 - 1; + EXPECT_THAT( + setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &below_half_min_so_rcvbuf, sizeof(below_half_min_so_rcvbuf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(min_so_rcvbuf / 2, get); + } +} + TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { DisableSave ds; // Too many syscalls. constexpr int kThreadCount = 1000; diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index d9c1ac0e1..a4d2953e1 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1313,7 +1313,7 @@ TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { int get = -1; socklen_t get_len = sizeof(get); ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); EXPECT_EQ(get, 0); @@ -1326,7 +1326,7 @@ TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { int get = -1; socklen_t get_len = sizeof(get); ASSERT_THAT( - getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); EXPECT_EQ(get, 0); @@ -1378,6 +1378,187 @@ TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { SyscallSucceedsWithValue(0)); } +TEST_P(SimpleTcpSocketTest, SetTCPSynCntLessThanOne) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int default_syn_cnt = get; + + { + // TCP_SYNCNT less than 1 should be rejected with an EINVAL. + constexpr int kZero = 0; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kZero, sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); + + // TCP_SYNCNT less than 1 should be rejected with an EINVAL. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(default_syn_cnt, get); + } +} + +TEST_P(SimpleTcpSocketTest, GetTCPSynCntDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + constexpr int kDefaultSynCnt = 6; + + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultSynCnt); +} + +TEST_P(SimpleTcpSocketTest, SetTCPSynCntGreaterThanOne) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int kTCPSynCnt = 20; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt, + sizeof(kTCPSynCnt)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPSynCnt); +} + +TEST_P(SimpleTcpSocketTest, SetTCPSynCntAboveMax) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int default_syn_cnt = get; + { + constexpr int kTCPSynCnt = 256; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt, + sizeof(kTCPSynCnt)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, default_syn_cnt); + } +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampBelowMinRcvBuf) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + // TCP_WINDOW_CLAMP less than min_so_rcvbuf/2 should be set to + // min_so_rcvbuf/2. + int below_half_min_rcvbuf = min_so_rcvbuf / 2 - 1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &below_half_min_rcvbuf, sizeof(below_half_min_rcvbuf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(min_so_rcvbuf / 2, get); + } +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampZeroClosedSocket) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int kZero = 0; + ASSERT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kZero); +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + int above_half_min_rcv_buf = min_so_rcvbuf / 2 + 1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &above_half_min_rcv_buf, sizeof(above_half_min_rcv_buf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(above_half_min_rcv_buf, get); + } +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 420b791a3d6e0e6e2fc30c6f8be013bce7ca6549 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Fri, 15 May 2020 20:03:54 -0700 Subject: Minor formatting updates for gvisor.dev. * Aggregate architecture Overview in "What is gVisor?" as it makes more sense in one place. * Drop "user-space kernel" and use "application kernel". The term "user-space kernel" is confusing when some platform implementation do not run in user-space (instead running in guest ring zero). * Clear up the relationship between the Platform page in the user guide and the Platform page in the architecture guide, and ensure they are cross-linked. * Restore the call-to-action quick start link in the main page, and drop the GitHub link (which also appears in the top-right). * Improve image formatting by centering all doc and blog images, and move the image captions to the alt text. PiperOrigin-RevId: 311845158 --- README.md | 79 ++--- g3doc/BUILD | 4 + g3doc/Layers.png | Bin 0 -> 11044 bytes g3doc/Layers.svg | 1 + g3doc/Machine-Virtualization.png | Bin 0 -> 13205 bytes g3doc/Machine-Virtualization.svg | 1 + g3doc/README.md | 161 +++++++++- g3doc/Rule-Based-Execution.png | Bin 0 -> 6780 bytes g3doc/Rule-Based-Execution.svg | 1 + g3doc/Sentry-Gofer.png | Bin 0 -> 9064 bytes g3doc/Sentry-Gofer.svg | 1 + g3doc/architecture_guide/BUILD | 30 +- g3doc/architecture_guide/Layers.png | Bin 11044 -> 0 bytes g3doc/architecture_guide/Layers.svg | 1 - .../architecture_guide/Machine-Virtualization.png | Bin 13205 -> 0 bytes .../architecture_guide/Machine-Virtualization.svg | 1 - g3doc/architecture_guide/README.md | 83 ----- g3doc/architecture_guide/Rule-Based-Execution.png | Bin 6780 -> 0 bytes g3doc/architecture_guide/Rule-Based-Execution.svg | 1 - g3doc/architecture_guide/Sentry-Gofer.png | Bin 9064 -> 0 bytes g3doc/architecture_guide/Sentry-Gofer.svg | 1 - g3doc/architecture_guide/performance.md | 35 ++- g3doc/architecture_guide/platforms.md | 109 +++---- g3doc/architecture_guide/platforms.png | Bin 0 -> 21384 bytes g3doc/architecture_guide/platforms.svg | 334 +++++++++++++++++++++ g3doc/architecture_guide/resources.md | 27 +- g3doc/architecture_guide/resources.png | Bin 0 -> 16621 bytes g3doc/architecture_guide/resources.svg | 208 +++++++++++++ g3doc/architecture_guide/security.md | 28 +- g3doc/architecture_guide/security.png | Bin 0 -> 16932 bytes g3doc/architecture_guide/security.svg | 153 ++++++++++ g3doc/user_guide/filesystem.md | 4 +- g3doc/user_guide/platforms.md | 100 +++--- pkg/sentry/arch/syscalls_arm64.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/task_syscall.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 6 +- pkg/tcpip/stack/stack.go | 4 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 2 +- runsc/cmd/help.go | 12 +- website/BUILD | 1 - website/_layouts/docs.html | 2 + website/_sass/front.scss | 4 +- website/_sass/style.scss | 10 + website/blog/2019-11-18-security-basics.md | 28 +- website/blog/2020-04-02-networking-security.md | 8 +- website/index.md | 10 +- 48 files changed, 1054 insertions(+), 406 deletions(-) create mode 100644 g3doc/Layers.png create mode 100644 g3doc/Layers.svg create mode 100644 g3doc/Machine-Virtualization.png create mode 100644 g3doc/Machine-Virtualization.svg create mode 100644 g3doc/Rule-Based-Execution.png create mode 100644 g3doc/Rule-Based-Execution.svg create mode 100644 g3doc/Sentry-Gofer.png create mode 100644 g3doc/Sentry-Gofer.svg delete mode 100644 g3doc/architecture_guide/Layers.png delete mode 100644 g3doc/architecture_guide/Layers.svg delete mode 100644 g3doc/architecture_guide/Machine-Virtualization.png delete mode 100644 g3doc/architecture_guide/Machine-Virtualization.svg delete mode 100644 g3doc/architecture_guide/README.md delete mode 100644 g3doc/architecture_guide/Rule-Based-Execution.png delete mode 100644 g3doc/architecture_guide/Rule-Based-Execution.svg delete mode 100644 g3doc/architecture_guide/Sentry-Gofer.png delete mode 100644 g3doc/architecture_guide/Sentry-Gofer.svg create mode 100644 g3doc/architecture_guide/platforms.png create mode 100644 g3doc/architecture_guide/platforms.svg create mode 100644 g3doc/architecture_guide/resources.png create mode 100644 g3doc/architecture_guide/resources.svg create mode 100644 g3doc/architecture_guide/security.png create mode 100644 g3doc/architecture_guide/security.svg (limited to 'pkg/sentry/socket/netstack') diff --git a/README.md b/README.md index de3e06f4e..442f5672a 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ ## What is gVisor? -**gVisor** is a user-space kernel, written in Go, that implements a substantial +**gVisor** is a application kernel, written in Go, that implements a substantial portion of the Linux system surface. It includes an [Open Container Initiative (OCI)][oci] runtime called `runsc` that provides an isolation boundary between the application and the host kernel. The `runsc` @@ -15,16 +15,17 @@ containers. ## Why does gVisor exist? Containers are not a [**sandbox**][sandbox]. While containers have -revolutionized how we develop, package, and deploy applications, running -untrusted or potentially malicious code without additional isolation is not a -good idea. The efficiency and performance gains from using a single, shared -kernel also mean that container escape is possible with a single vulnerability. - -gVisor is a user-space kernel for containers. It limits the host kernel surface -accessible to the application while still giving the application access to all -the features it expects. Unlike most kernels, gVisor does not assume or require -a fixed set of physical resources; instead, it leverages existing host kernel -functionality and runs as a normal user-space process. In other words, gVisor +revolutionized how we develop, package, and deploy applications, using them to +run untrusted or potentially malicious code without additional isolation is not +a good idea. While using a single, shared kernel allows for efficiency and +performance gains, it also means that container escape is possible with a single +vulnerability. + +gVisor is an application kernel for containers. It limits the host kernel +surface accessible to the application while still giving the application access +to all the features it expects. Unlike most kernels, gVisor does not assume or +require a fixed set of physical resources; instead, it leverages existing host +kernel functionality and runs as a normal process. In other words, gVisor implements Linux by way of Linux. gVisor should not be confused with technologies and tools to harden containers @@ -39,33 +40,24 @@ be found at [gvisor.dev][gvisor-dev]. ## Installing from source -gVisor currently requires x86\_64 Linux to build, though support for other -architectures may become available in the future. +gVisor builds on x86_64 and ARM64. Other architectures may become available in +the future. + +For the purposes of these instructions, [bazel][bazel] and other build +dependencies are wrapped in a build container. It is possible to use +[bazel][bazel] directly, or type `make help` for standard targets. ### Requirements Make sure the following dependencies are installed: * Linux 4.14.77+ ([older linux][old-linux]) -* [git][git] -* [Bazel][bazel] 1.2+ -* [Python][python] * [Docker version 17.09.0 or greater][docker] -* C++ toolchain supporting C++17 (GCC 7+, Clang 5+) -* Gold linker (e.g. `binutils-gold` package on Ubuntu) ### Building Build and install the `runsc` binary: -``` -bazel build runsc -sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin -``` - -If you don't want to install bazel on your system, you can build runsc in a -Docker container: - ``` make runsc sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin @@ -73,41 +65,19 @@ sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin ### Testing -The test suite can be run with Bazel: - -``` -bazel test //... -``` - -or in a Docker container: +To run standard test suites, you can use: ``` make unit-tests make tests ``` -### Using remote execution - -If you have a [Remote Build Execution][rbe] environment, you can use it to speed -up build and test cycles. - -You must authenticate with the project first: +To run specific tests, you can specify the target: ``` -gcloud auth application-default login --no-launch-browser +make test TARGET="//runsc:version_test" ``` -Then invoke bazel with the following flags: - -``` ---config=remote ---project_id=$PROJECT ---remote_instance_name=projects/$PROJECT/instances/default_instance -``` - -You can also add those flags to your local ~/.bazelrc to avoid needing to -specify them each time on the command line. - ### Using `go get` This project uses [bazel][bazel] to build and manage dependencies. A synthetic @@ -128,7 +98,7 @@ development on this branch is not supported. Development should occur on the ## Community & Governance -The governance model is documented in our [community][community] repository. +See [GOVERNANCE.md](GOVERANCE.md) for project governance information. The [gvisor-users mailing list][gvisor-users-list] and [gvisor-dev mailing list][gvisor-dev-list] are good starting points for @@ -145,12 +115,9 @@ See [Contributing.md](CONTRIBUTING.md). [bazel]: https://bazel.build [community]: https://gvisor.googlesource.com/community [docker]: https://www.docker.com -[git]: https://git-scm.com [gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users +[gvisor-dev]: https://gvisor.dev [gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev [oci]: https://www.opencontainers.org [old-linux]: https://gvisor.dev/docs/user_guide/networking/#gso -[python]: https://python.org -[rbe]: https://blog.bazel.build/2018/10/05/remote-build-execution.html [sandbox]: https://en.wikipedia.org/wiki/Sandbox_(computer_security) -[gvisor-dev]: https://gvisor.dev diff --git a/g3doc/BUILD b/g3doc/BUILD index 24177ad06..dbbf96204 100644 --- a/g3doc/BUILD +++ b/g3doc/BUILD @@ -9,6 +9,10 @@ doc( name = "index", src = "README.md", category = "Project", + data = glob([ + "*.png", + "*.svg", + ]), permalink = "/docs/", weight = "0", ) diff --git a/g3doc/Layers.png b/g3doc/Layers.png new file mode 100644 index 000000000..308c6c451 Binary files /dev/null and b/g3doc/Layers.png differ diff --git a/g3doc/Layers.svg b/g3doc/Layers.svg new file mode 100644 index 000000000..0a366f841 --- /dev/null +++ b/g3doc/Layers.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/Machine-Virtualization.png b/g3doc/Machine-Virtualization.png new file mode 100644 index 000000000..1ba2ed6b2 Binary files /dev/null and b/g3doc/Machine-Virtualization.png differ diff --git a/g3doc/Machine-Virtualization.svg b/g3doc/Machine-Virtualization.svg new file mode 100644 index 000000000..5352da07b --- /dev/null +++ b/g3doc/Machine-Virtualization.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/README.md b/g3doc/README.md index 7999f5d47..304a91493 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -1,6 +1,6 @@ # What is gVisor? -gVisor is a user-space kernel, written in Go, that implements a substantial +gVisor is an application kernel, written in Go, that implements a substantial portion of the [Linux system call interface][linux]. It provides an additional layer of isolation between running applications and the host operating system. @@ -9,19 +9,160 @@ that makes it easy to work with existing container tooling. The `runsc` runtime integrates with Docker and Kubernetes, making it simple to run sandboxed containers. -gVisor takes a distinct approach to container sandboxing and makes a different -set of technical trade-offs compared to existing sandbox technologies, thus -providing new tools and ideas for the container security landscape. - gVisor can be used with Docker, Kubernetes, or directly using `runsc`. Use the links below to see detailed instructions for each of them: -* [Docker](./user_guide/quick_start/docker/): The quickest and easiest way to - get started. -* [Kubernetes](./user_guide/quick_start/kubernetes/): Isolate Pods in your K8s - cluster with gVisor. -* [OCI Quick Start](./user_guide/quick_start/oci/): Expert mode. Customize +* [Docker](./user_guide/quick_start/docker.md): The quickest and easiest way + to get started. +* [Kubernetes](./user_guide/quick_start/kubernetes.md): Isolate Pods in your + K8s cluster with gVisor. +* [OCI Quick Start](./user_guide/quick_start/oci.md): Expert mode. Customize gVisor for your environment. +## What does gVisor do? + +gVisor provides a virtualized environment in order to sandbox containers. The +system interfaces normally implemented by the host kernel are moved into a +distinct, per-sandbox application kernel in order to minimize the risk of an +container escape exploit. gVisor does not introduce large fixed overheads +however, and still retains a process-like model with respect to resource +utilization. + +## How is this different? + +Two other approaches are commonly taken to provide stronger isolation than +native containers. + +**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes +virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This +virtualized hardware is generally enlightened (paravirtualized) and additional +mechanisms can be used to improve the visibility between the guest and host +(e.g. balloon drivers, paravirtualized spinlocks). Running containers in +distinct virtual machines can provide great isolation, compatibility and +performance (though nested virtualization may bring challenges in this area), +but for containers it often requires additional proxies and agents, and may +require a larger resource footprint and slower start-up times. + +![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization") + +**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and +[AppArmor][apparmor], allows the specification of a fine-grained security policy +for an application or container. These schemes typically rely on hooks +implemented inside the host kernel to enforce the rules. If the surface can be +made small enough, then this is an excellent way to sandbox applications and +maintain native performance. However, in practice it can be extremely difficult +(if not impossible) to reliably define a policy for arbitrary, previously +unknown applications, making this approach challenging to apply universally. + +![Rule-based execution](Rule-Based-Execution.png "Rule-based execution") + +Rule-based execution is often combined with additional layers for +defense-in-depth. + +**gVisor** provides a third isolation mechanism, distinct from those above. + +gVisor intercepts application system calls and acts as the guest kernel, without +the need for translation through virtualized hardware. gVisor may be thought of +as either a merged guest kernel and VMM, or as seccomp on steroids. This +architecture allows it to provide a flexible resource footprint (i.e. one based +on threads and memory mappings, not fixed guest physical resources) while also +lowering the fixed costs of virtualization. However, this comes at the price of +reduced application compatibility and higher per-system call overhead. + +![gVisor](Layers.png "gVisor") + +On top of this, gVisor employs rule-based execution to provide defense-in-depth +(details below). + +gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML +virtualizes hardware internally and thus provides a fixed resource footprint. + +Each of the above approaches may excel in distinct scenarios. For example, +machine-level virtualization will face challenges achieving high density, while +gVisor may provide poor performance for system call heavy workloads. + +## Why Go? + +gVisor is written in [Go][golang] in order to avoid security pitfalls that can +plague kernels. With Go, there are strong types, built-in bounds checks, no +uninitialized variables, no use-after-free, no stack overflow, and a built-in +race detector. However, the use of Go has its challenges, and the runtime often +introduces performance overhead. + +## What are the different components? + +A gVisor sandbox consists of multiple processes. These processes collectively +comprise an environment in which one or more containers can be run. + +Each sandbox has its own isolated instance of: + +* The **Sentry**, which is a kernel that runs the containers and intercepts + and responds to system calls made by the application. + +Each container running in the sandbox has its own isolated instance of: + +* A **Gofer** which provides file system access to the containers. + +![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram") + +## What is runsc? + +The entrypoint to running a sandboxed container is the `runsc` executable. +`runsc` implements the [Open Container Initiative (OCI)][oci] runtime +specification, which is used by Docker and Kubernetes. This means that OCI +compatible _filesystem bundles_ can be run by `runsc`. Filesystem bundles are +comprised of a `config.json` file containing container configuration, and a root +filesystem for the container. Please see the [OCI runtime spec][runtime-spec] +for more information on filesystem bundles. `runsc` implements multiple commands +that perform various functions such as starting, stopping, listing, and querying +the status of containers. + +### Sentry + + + +The Sentry is the largest component of gVisor. It can be thought of as a +application kernel. The Sentry implements all the kernel functionality needed by +the application, including: system calls, signal delivery, memory management and +page faulting logic, the threading model, and more. + +When the application makes a system call, the +[Platform](./architecture_guide/platforms.md) redirects the call to the Sentry, +which will do the necessary work to service it. It is important to note that the +Sentry does not pass system calls through to the host kernel. As a userspace +application, the Sentry will make some host system calls to support its +operation, but it does not allow the application to directly control the system +calls it makes. For example, the Sentry is not able to open files directly; file +system operations that extend beyond the sandbox (not internal `/proc` files, +pipes, etc) are sent to the Gofer, described below. + +### Gofer + + + +The Gofer is a standard host process which is started with each container and +communicates with the Sentry via the [9P protocol][9p] over a socket or shared +memory channel. The Sentry process is started in a restricted seccomp container +without access to file system resources. The Gofer mediates all access to the +these resources, providing an additional level of isolation. + +### Application + +The application is a normal Linux binary provided to gVisor in an OCI runtime +bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so +applications should be able to run unmodified. However, gVisor does not +presently implement every system call, `/proc` file, or `/sys` file so some +incompatibilities may occur. See [Commpatibility](./user_guide/compatibility.md) +for more information. + +[9p]: https://en.wikipedia.org/wiki/9P_(protocol) +[apparmor]: https://wiki.ubuntu.com/AppArmor +[golang]: https://golang.org +[kvm]: https://www.linux-kvm.org [linux]: https://en.wikipedia.org/wiki/Linux_kernel_interfaces [oci]: https://www.opencontainers.org +[runtime-spec]: https://github.com/opencontainers/runtime-spec +[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt +[selinux]: https://selinuxproject.org +[uml]: http://user-mode-linux.sourceforge.net/ +[xen]: https://www.xenproject.org diff --git a/g3doc/Rule-Based-Execution.png b/g3doc/Rule-Based-Execution.png new file mode 100644 index 000000000..b42654a90 Binary files /dev/null and b/g3doc/Rule-Based-Execution.png differ diff --git a/g3doc/Rule-Based-Execution.svg b/g3doc/Rule-Based-Execution.svg new file mode 100644 index 000000000..bd6717043 --- /dev/null +++ b/g3doc/Rule-Based-Execution.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/Sentry-Gofer.png b/g3doc/Sentry-Gofer.png new file mode 100644 index 000000000..ca2c27ef7 Binary files /dev/null and b/g3doc/Sentry-Gofer.png differ diff --git a/g3doc/Sentry-Gofer.svg b/g3doc/Sentry-Gofer.svg new file mode 100644 index 000000000..5c10750d2 --- /dev/null +++ b/g3doc/Sentry-Gofer.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/architecture_guide/BUILD b/g3doc/architecture_guide/BUILD index 72038305b..404f627a4 100644 --- a/g3doc/architecture_guide/BUILD +++ b/g3doc/architecture_guide/BUILD @@ -5,31 +5,13 @@ package( licenses = ["notice"], ) -doc( - name = "index", - src = "README.md", - category = "Architecture Guide", - data = [ - "Layers.png", - "Layers.svg", - "Machine-Virtualization.png", - "Machine-Virtualization.svg", - "Rule-Based-Execution.png", - "Rule-Based-Execution.svg", - "Sentry-Gofer.png", - "Sentry-Gofer.svg", - ], - permalink = "/docs/architecture_guide/", - weight = "0", -) - doc( name = "platforms", src = "platforms.md", category = "Architecture Guide", data = [ - "Sentry-Gofer.png", - "Sentry-Gofer.svg", + "platforms.png", + "platforms.svg", ], permalink = "/docs/architecture_guide/platforms/", weight = "40", @@ -39,6 +21,10 @@ doc( name = "resources", src = "resources.md", category = "Architecture Guide", + data = [ + "resources.png", + "resources.svg", + ], permalink = "/docs/architecture_guide/resources/", weight = "30", ) @@ -48,8 +34,8 @@ doc( src = "security.md", category = "Architecture Guide", data = [ - "Layers.png", - "Layers.svg", + "security.png", + "security.svg", ], permalink = "/docs/architecture_guide/security/", weight = "10", diff --git a/g3doc/architecture_guide/Layers.png b/g3doc/architecture_guide/Layers.png deleted file mode 100644 index 308c6c451..000000000 Binary files a/g3doc/architecture_guide/Layers.png and /dev/null differ diff --git a/g3doc/architecture_guide/Layers.svg b/g3doc/architecture_guide/Layers.svg deleted file mode 100644 index 0a366f841..000000000 --- a/g3doc/architecture_guide/Layers.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/Machine-Virtualization.png b/g3doc/architecture_guide/Machine-Virtualization.png deleted file mode 100644 index 1ba2ed6b2..000000000 Binary files a/g3doc/architecture_guide/Machine-Virtualization.png and /dev/null differ diff --git a/g3doc/architecture_guide/Machine-Virtualization.svg b/g3doc/architecture_guide/Machine-Virtualization.svg deleted file mode 100644 index 5352da07b..000000000 --- a/g3doc/architecture_guide/Machine-Virtualization.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/README.md b/g3doc/architecture_guide/README.md deleted file mode 100644 index ab9ef7174..000000000 --- a/g3doc/architecture_guide/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# Overview - -gVisor provides a virtualized environment in order to sandbox untrusted -containers. The system interfaces normally implemented by the host kernel are -moved into a distinct, per-sandbox user space kernel in order to minimize the -risk of an exploit. gVisor does not introduce large fixed overheads however, and -still retains a process-like model with respect to resource utilization. - -## How is this different? - -Two other approaches are commonly taken to provide stronger isolation than -native containers. - -**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes -virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This -virtualized hardware is generally enlightened (paravirtualized) and additional -mechanisms can be used to improve the visibility between the guest and host -(e.g. balloon drivers, paravirtualized spinlocks). Running containers in -distinct virtual machines can provide great isolation, compatibility and -performance (though nested virtualization may bring challenges in this area), -but for containers it often requires additional proxies and agents, and may -require a larger resource footprint and slower start-up times. - -![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization") - -**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and -[AppArmor][apparmor], allows the specification of a fine-grained security policy -for an application or container. These schemes typically rely on hooks -implemented inside the host kernel to enforce the rules. If the surface can be -made small enough (i.e. a sufficiently complete policy defined), then this is an -excellent way to sandbox applications and maintain native performance. However, -in practice it can be extremely difficult (if not impossible) to reliably define -a policy for arbitrary, previously unknown applications, making this approach -challenging to apply universally. - -![Rule-based execution](Rule-Based-Execution.png "Rule-based execution") - -Rule-based execution is often combined with additional layers for -defense-in-depth. - -**gVisor** provides a third isolation mechanism, distinct from those above. - -gVisor intercepts application system calls and acts as the guest kernel, without -the need for translation through virtualized hardware. gVisor may be thought of -as either a merged guest kernel and VMM, or as seccomp on steroids. This -architecture allows it to provide a flexible resource footprint (i.e. one based -on threads and memory mappings, not fixed guest physical resources) while also -lowering the fixed costs of virtualization. However, this comes at the price of -reduced application compatibility and higher per-system call overhead. - -![gVisor](Layers.png "gVisor") - -On top of this, gVisor employs rule-based execution to provide defense-in-depth -(details below). - -gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML -virtualizes hardware internally and thus provides a fixed resource footprint. - -Each of the above approaches may excel in distinct scenarios. For example, -machine-level virtualization will face challenges achieving high density, while -gVisor may provide poor performance for system call heavy workloads. - -### Why Go? - -gVisor is written in [Go][golang] in order to avoid security pitfalls that can -plague kernels. With Go, there are strong types, built-in bounds checks, no -uninitialized variables, no use-after-free, no stack overflow, and a built-in -race detector. (The use of Go has its challenges too, and isn't free.) - -### What about Gofers? - - - -Gofers mediate file system interactions, and are used to provide additional -isolation. For more details, see the [Platform Guide](./platforms.md). - -[apparmor]: https://wiki.ubuntu.com/AppArmor -[golang]: https://golang.org -[kvm]: https://www.linux-kvm.org -[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt -[selinux]: https://selinuxproject.org -[uml]: http://user-mode-linux.sourceforge.net/ -[xen]: https://www.xenproject.org diff --git a/g3doc/architecture_guide/Rule-Based-Execution.png b/g3doc/architecture_guide/Rule-Based-Execution.png deleted file mode 100644 index b42654a90..000000000 Binary files a/g3doc/architecture_guide/Rule-Based-Execution.png and /dev/null differ diff --git a/g3doc/architecture_guide/Rule-Based-Execution.svg b/g3doc/architecture_guide/Rule-Based-Execution.svg deleted file mode 100644 index bd6717043..000000000 --- a/g3doc/architecture_guide/Rule-Based-Execution.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/Sentry-Gofer.png b/g3doc/architecture_guide/Sentry-Gofer.png deleted file mode 100644 index ca2c27ef7..000000000 Binary files a/g3doc/architecture_guide/Sentry-Gofer.png and /dev/null differ diff --git a/g3doc/architecture_guide/Sentry-Gofer.svg b/g3doc/architecture_guide/Sentry-Gofer.svg deleted file mode 100644 index 5c10750d2..000000000 --- a/g3doc/architecture_guide/Sentry-Gofer.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md index 3862d78ee..39dbb0045 100644 --- a/g3doc/architecture_guide/performance.md +++ b/g3doc/architecture_guide/performance.md @@ -13,12 +13,13 @@ forms: additional cycles and memory usage, which may manifest as increased latency, reduced throughput or density, or not at all. In general, these costs come from two different sources. -First, the existence of the [Sentry](../) means that additional memory will be -required, and application system calls must traverse additional layers of -software. The design emphasizes [security](../security/) and therefore we chose -to use a language for the Sentry that provides benefits in this domain but may -not yet offer the raw performance of other choices. Costs imposed by these -design choices are **structural costs**. +First, the existence of the [Sentry](../README.md#sentry) means that additional +memory will be required, and application system calls must traverse additional +layers of software. The design emphasizes +[security](/docs/architecture_guide/security/) and therefore we chose to use a +language for the Sentry that provides benefits in this domain but may not yet +offer the raw performance of other choices. Costs imposed by these design +choices are **structural costs**. Second, as gVisor is an independent implementation of the system call surface, many of the subsystems or specific calls are not as optimized as more mature @@ -50,7 +51,7 @@ Virtual Machines (VMs) with the following specifications: Through this document, `runsc` is used to indicate the runtime provided by gVisor. When relevant, we use the name `runsc-platform` to describe a specific -[platform choice](../platforms/). +[platform choice](/docs/architecture_guide/platforms/). **Except where specified, all tests below are conducted with the `ptrace` platform. The `ptrace` platform works everywhere and does not require hardware @@ -131,11 +132,11 @@ full start-up and run time for the workload, which trains a model. ## System calls Some **structural costs** of gVisor are heavily influenced by the -[platform choice](../platforms/), which implements system call interception. -Today, gVisor supports a variety of platforms. These platforms present distinct -performance, compatibility and security trade-offs. For example, the KVM -platform has low overhead system call interception but runs poorly with nested -virtualization. +[platform choice](/docs/architecture_guide/platforms/), which implements system +call interception. Today, gVisor supports a variety of platforms. These +platforms present distinct performance, compatibility and security trade-offs. +For example, the KVM platform has low overhead system call interception but runs +poorly with nested virtualization. {% include graph.html id="syscall" url="/performance/syscall.csv" title="perf.py syscall --runtime=runc --runtime=runsc-ptrace --runtime=runsc-kvm" y_min="100" @@ -163,7 +164,8 @@ overhead. Some of these costs above are **structural costs**, and `redis` is likely to remain a challenging performance scenario. However, optimizing the -[platform](../platforms/) will also have a dramatic impact. +[platform](/docs/architecture_guide/platforms/) will also have a dramatic +impact. ## Start-up time @@ -184,7 +186,7 @@ similarly loads a number of modules and binds an HTTP server. > Note: most of the time overhead above is associated Docker itself. This is > evident with the empty `runc` benchmark. To avoid these costs with `runsc`, > you may also consider using `runsc do` mode or invoking the -> [OCI runtime](../../user_guide/quick_start/oci/) directly. +> [OCI runtime](../user_guide/quick_start/oci.md) directly. ## Network @@ -222,8 +224,9 @@ In terms of raw disk I/O, gVisor does not introduce significant fundamental overhead. For general file operations, gVisor introduces a small fixed overhead for data that transitions across the sandbox boundary. This manifests as **structural costs** in some cases, since these operations must be routed -through the [Gofer](../) as a result of our [security model](../security/), but -in most cases are dominated by **implementation costs**, due to an internal +through the [Gofer](../README.md#gofer) as a result of our +[Security Model](/docs/architecture_guide/security/), but in most cases are +dominated by **implementation costs**, due to an internal [Virtual File System][vfs] (VFS) implementation that needs improvement. {% include graph.html id="fio-bw" url="/performance/fio.csv" title="perf.py fio diff --git a/g3doc/architecture_guide/platforms.md b/g3doc/architecture_guide/platforms.md index 6e63da8ce..d112c9a28 100644 --- a/g3doc/architecture_guide/platforms.md +++ b/g3doc/architecture_guide/platforms.md @@ -1,86 +1,61 @@ # Platform Guide -A gVisor sandbox consists of multiple processes when running. These processes -collectively comprise a shared environment in which one or more containers can -be run. +[TOC] -Each sandbox has its own isolated instance of: - -* The **Sentry**, A user-space kernel that runs the container and intercepts - and responds to system calls made by the application. - -Each container running in the sandbox has its own isolated instance of: - -* A **Gofer** which provides file system access to the container. - -![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram") - -## runsc - -The entrypoint to running a sandboxed container is the `runsc` executable. -`runsc` implements the [Open Container Initiative (OCI)][oci] runtime -specification. This means that OCI compatible _filesystem bundles_ can be run by -`runsc`. Filesystem bundles are comprised of a `config.json` file containing -container configuration, and a root filesystem for the container. Please see the -[OCI runtime spec][runtime-spec] for more information on filesystem bundles. -`runsc` implements multiple commands that perform various functions such as -starting, stopping, listing, and querying the status of containers. +gVisor requires a platform to implement interception of syscalls, basic context +switching, and memory mapping functionality. Internally, gVisor uses an +abstraction sensibly called [Platform][platform]. A simplified version of this +interface looks like: -## Sentry +```golang +type Platform interface { + NewAddressSpace() (AddressSpace, error) + NewContext() Context +} -The Sentry is the largest component of gVisor. It can be thought of as a -userspace OS kernel. The Sentry implements all the kernel functionality needed -by the untrusted application. It implements all of the supported system calls, -signal delivery, memory management and page faulting logic, the threading model, -and more. +type Context interface { + Switch(as AddressSpace, ac arch.Context) (..., error) +} -When the untrusted application makes a system call, the currently used platform -redirects the call to the Sentry, which will do the necessary work to service -it. It is important to note that the Sentry will not simply pass through system -calls to the host kernel. As a userspace application, the Sentry will make some -host system calls to support its operation, but it will not allow the -application to directly control the system calls it makes. +type AddressSpace interface { + MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, ...) error + Unmap(addr usermem.Addr, length uint64) +} +``` -The Sentry aims to present an equivalent environment to (upstream) Linux v4.4. +There are a number of different ways to implement this interface that come with +various trade-offs, generally around performance and hardware requirements. -File system operations that extend beyond the sandbox (not internal /proc files, -pipes, etc) are sent to the Gofer, described below. +## Implementations -## Platforms +The choice of platform depends on the context in which `runsc` is executing. In +general, virtualized platforms may be limited to platforms that do not require +hardware virtualized support (since the hardware is already in use): -gVisor requires a platform to implement interception of syscalls, basic context -switching, and memory mapping functionality. +![Platforms](platforms.png "Platform examples.") ### ptrace -The ptrace platform uses `PTRACE_SYSEMU` to execute user code without allowing -it to execute host system calls. This platform can run anywhere that ptrace -works (even VMs without nested virtualization). - -### KVM (experimental) +The ptrace platform uses [PTRACE_SYSEMU][ptrace] to execute user code without +allowing it to execute host system calls. This platform can run anywhere that +`ptrace` works (even VMs without nested virtualization), which is ubiquitous. -The KVM platform allows the Sentry to act as both guest OS and VMM, switching -back and forth between the two worlds seamlessly. The KVM platform can run on -bare-metal or in a VM with nested virtualization enabled. While there is no -virtualized hardware layer -- the sandbox retains a process model -- gVisor -leverages virtualization extensions available on modern processors in order to -improve isolation and performance of address space switches. +Unfortunately, the ptrace platform has high context switch overhead, so system +call-heavy applications may pay a [performance penalty](./performance.md). -## Gofer +### KVM -The Gofer is a normal host Linux process. The Gofer is started with each sandbox -and connected to the Sentry. The Sentry process is started in a restricted -seccomp container without access to file system resources. The Gofer provides -the Sentry access to file system resources via the 9P protocol and provides an -additional level of isolation. +The KVM platform uses the kernel's [KVM][kvm] functionality to allow the Sentry +to act as both guest OS and VMM. The KVM platform can run on bare-metal or in a +VM with nested virtualization enabled. While there is no virtualized hardware +layer -- the sandbox retains a process model -- gVisor leverages virtualization +extensions available on modern processors in order to improve isolation and +performance of address space switches. -## Application +## Changing Platforms -The application (aka the untrusted application) is a normal Linux binary -provided to gVisor in an OCI runtime bundle. gVisor aims to provide an -environment equivalent to Linux v4.4, so applications should be able to run -unmodified. However, gVisor does not presently implement every system call, -/proc file, or /sys file so some incompatibilities may occur. +See [Changing Platforms](../user_guide/platforms.md). -[oci]: https://www.opencontainers.org -[runtime-spec]: https://github.com/opencontainers/runtime-spec +[kvm]: https://www.kernel.org/doc/Documentation/virtual/kvm/api.txt +[platform]: https://cs.opensource.google/gvisor/gvisor/+/release-20190304.1:pkg/sentry/platform/platform.go;l=33 +[ptrace]: http://man7.org/linux/man-pages/man2/ptrace.2.html diff --git a/g3doc/architecture_guide/platforms.png b/g3doc/architecture_guide/platforms.png new file mode 100644 index 000000000..005d56feb Binary files /dev/null and b/g3doc/architecture_guide/platforms.png differ diff --git a/g3doc/architecture_guide/platforms.svg b/g3doc/architecture_guide/platforms.svg new file mode 100644 index 000000000..b0bac9ba7 --- /dev/null +++ b/g3doc/architecture_guide/platforms.svg @@ -0,0 +1,334 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + gVisor + workload + host + + + + gVisor + workload + + KVM + + + gVisor + workload + + ptrace + + ptrace + VM + + guest + + + gVisor + workload + + ptrace + + diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md index 894f995ae..1dec37bd1 100644 --- a/g3doc/architecture_guide/resources.md +++ b/g3doc/architecture_guide/resources.md @@ -10,9 +10,10 @@ sandbox to be highly dynamic in terms of resource usage: spanning a large number of cores and large amount of memory when busy, and yielding those resources back to the host when not. -Some of the details here may depend on the [platform](../platforms/), but in -general this page describes the resource model used by gVisor. If you're not -familiar with the terms here, uou may want to start with the [Overview](../). +In order words, the shape of the sandbox should closely track the shape of the +sandboxed process: + +![Resource model](resources.png "Workloads of different shapes.") ## Processes @@ -23,9 +24,9 @@ the sandbox (e.g. via a [Docker exec][exec]). ## Networking -Similarly to processes, the sandbox attaches a network endpoint to the system, -but runs it's own network stack. All network resources, other than packets in -flight, exist only inside the sandbox, bound by relevant resource limits. +The sandbox attaches a network endpoint to the system, but runs it's own network +stack. All network resources, other than packets in flight on the host, exist +only inside the sandbox, bound by relevant resource limits. You can interact with network endpoints exposed by the sandbox, just as you would any other container, but network introspection similarly requires entering @@ -33,15 +34,14 @@ the sandbox. ## Files -Files may be backed by different implementations. For host-native files (where a -file descriptor is available), the Gofer may return a file descriptor to the -Sentry via [SCM_RIGHTS][scmrights][^1]. +Files in the sandbox may be backed by different implementations. For host-native +files (where a file descriptor is available), the Gofer may return a file +descriptor to the Sentry via [SCM_RIGHTS][scmrights][^1]. These files may be read from and written to through standard system calls, and also mapped into the associated application's address space. This allows the same host memory to be shared across multiple sandboxes, although this mechanism -does not preclude the use of side-channels (see the -[security model](../security/)). +does not preclude the use of side-channels (see [Security Model](./security.md). Note that some file systems exist only within the context of the sandbox. For example, in many cases a `tmpfs` mount will be available at `/tmp` or @@ -64,8 +64,9 @@ scheduling decisions about all application threads. ## Time Time in the sandbox is provided by the Sentry, through its own [vDSO][vdso] and -timekeeping implementation. This is divorced from the host time, and no state is -shared with the host, although the time will be initialized with the host clock. +time-keeping implementation. This is distinct from the host time, and no state +is shared with the host, although the time will be initialized with the host +clock. The Sentry runs timers to note the passage of time, much like a kernel running on hardware (though the timers are software timers, in this case). These timers diff --git a/g3doc/architecture_guide/resources.png b/g3doc/architecture_guide/resources.png new file mode 100644 index 000000000..f715008ec Binary files /dev/null and b/g3doc/architecture_guide/resources.png differ diff --git a/g3doc/architecture_guide/resources.svg b/g3doc/architecture_guide/resources.svg new file mode 100644 index 000000000..fd7805d90 --- /dev/null +++ b/g3doc/architecture_guide/resources.svg @@ -0,0 +1,208 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + gVisor + gVisor + gVisor + workload + workload + workload + host + + diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md index f78586291..b99b86332 100644 --- a/g3doc/architecture_guide/security.md +++ b/g3doc/architecture_guide/security.md @@ -86,15 +86,17 @@ a substitute for a secure architecture*. ## Goals: Limiting Exposure -gVisor’s primary design goal is to minimize the System API attack vector while -still providing a process model. There are two primary security principles that -inform this design. First, the application’s direct interactions with the host -System API are intercepted by the Sentry, which implements the System API -instead. Second, the System API accessible to the Sentry itself is minimized to -a safer, restricted set. The first principle minimizes the possibility of direct -exploitation of the host System API by applications, and the second principle -minimizes indirect exploitability, which is the exploitation by an exploited or -buggy Sentry (e.g. chaining an exploit). +![Threat model](security.png "Threat model.") + +gVisor’s primary design goal is to minimize the System API attack vector through +multiple layers of defense, while still providing a process model. There are two +primary security principles that inform this design. First, the application’s +direct interactions with the host System API are intercepted by the Sentry, +which implements the System API instead. Second, the System API accessible to +the Sentry itself is minimized to a safer, restricted set. The first principle +minimizes the possibility of direct exploitation of the host System API by +applications, and the second principle minimizes indirect exploitability, which +is the exploitation by an exploited or buggy Sentry (e.g. chaining an exploit). The first principle is similar to the security basis for a Virtual Machine (VM). With a VM, an application’s interactions with the host are replaced by @@ -210,9 +212,9 @@ crashes are recorded and triaged to similarly identify material issues. ### Is this more or less secure than a Virtual Machine? The security of a VM depends to a large extent on what is exposed from the host -kernel and user space support code. For example, device emulation code in the +kernel and userspace support code. For example, device emulation code in the host kernel (e.g. APIC) or optimizations (e.g. vhost) can be more complex than a -simple system call, and exploits carry the same risks. Similarly, the user space +simple system call, and exploits carry the same risks. Similarly, the userspace support code is frequently unsandboxed, and exploits, while rare, may allow unfettered access to the system. @@ -245,8 +247,8 @@ In gVisor, the platforms that use ptrace operate differently. The stubs that are traced are never allowed to continue execution into the host kernel and complete a call directly. Instead, all system calls are interpreted and handled by the Sentry itself, who reflects resulting register state back into the tracee before -continuing execution in user space. This is very similar to the mechanism used -by User-Mode Linux (UML). +continuing execution in userspace. This is very similar to the mechanism used by +User-Mode Linux (UML). [dirtycow]: https://en.wikipedia.org/wiki/Dirty_COW [clang]: https://en.wikipedia.org/wiki/C_(programming_language) diff --git a/g3doc/architecture_guide/security.png b/g3doc/architecture_guide/security.png new file mode 100644 index 000000000..c29befbf6 Binary files /dev/null and b/g3doc/architecture_guide/security.png differ diff --git a/g3doc/architecture_guide/security.svg b/g3doc/architecture_guide/security.svg new file mode 100644 index 000000000..0575e2dec --- /dev/null +++ b/g3doc/architecture_guide/security.svg @@ -0,0 +1,153 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/g3doc/user_guide/filesystem.md b/g3doc/user_guide/filesystem.md index 6c69f42a1..cd00762dd 100644 --- a/g3doc/user_guide/filesystem.md +++ b/g3doc/user_guide/filesystem.md @@ -4,8 +4,8 @@ gVisor accesses the filesystem through a file proxy, called the Gofer. The gofer runs as a separate process, that is isolated from the sandbox. Gofer instances -communicate with their respective sentry using the 9P protocol. For a more -detailed explanation see [Overview > Gofer](../../architecture_guide/#gofer). +communicate with their respective sentry using the 9P protocol. For another +explanation see [What is gVisor?](../README.md). ## Sandbox overlay diff --git a/g3doc/user_guide/platforms.md b/g3doc/user_guide/platforms.md index eefb6b222..752025881 100644 --- a/g3doc/user_guide/platforms.md +++ b/g3doc/user_guide/platforms.md @@ -1,56 +1,27 @@ -# Platforms (KVM) +# Changing Platforms [TOC] -This document will help you set up your system to use a different gVisor -platform. +This guide described how to change the +[platform](../architecture_guide/platforms.md) used by `runsc`. -## What is a Platform? +## Prerequisites -gVisor requires a *platform* to implement interception of syscalls, basic -context switching, and memory mapping functionality. These are described in more -depth in the [Platform Design](../../architecture_guide/platforms/). +If you intend to run the KVM platform, you will also to have KVM installed on +your system. If you are running a Debian based system like Debian or Ubuntu you +can usually do this by ensuring the module is loaded, and permissions are +appropriately set on the `/dev/kvm` device. -## Selecting a Platform - -The platform is selected by the `--platform` command line flag passed to -`runsc`. By default, the ptrace platform is selected. To select a different -platform, modify your Docker configuration (`/etc/docker/daemon.json`) to pass -this argument: - -```json -{ - "runtimes": { - "runsc": { - "path": "/usr/local/bin/runsc", - "runtimeArgs": [ - "--platform=kvm" - ] - } - } -} -``` - -You must restart the Docker daemon after making changes to this file, typically -this is done via `systemd`: +If you have an Intel CPU: ```bash -sudo systemctl restart docker +sudo modprobe kvm-intel && sudo chmod a+rw /dev/kvm ``` -## Example: Using the KVM Platform - -The KVM platform is currently experimental; however, it provides several -benefits over the default ptrace platform. - -### Prerequisites - -You will also to have KVM installed on your system. If you are running a Debian -based system like Debian or Ubuntu you can usually do this by installing the -`qemu-kvm` package. +If you have an AMD CPU: ```bash -sudo apt-get install qemu-kvm +sudo modprobe kvm-amd && sudo chmod a+rw /dev/kvm ``` If you are using a virtual machine you will need to make sure that nested @@ -68,31 +39,22 @@ cause of security issues (e.g. [CVE-2018-12904](https://nvd.nist.gov/vuln/detail/CVE-2018-12904)). It is not recommended for production.*** -### Configuring Docker - -Per above, you will need to configure Docker to use `runsc` with the KVM -platform. You will remember from the Docker Quick Start that you configured -Docker to use `runsc` as the runtime. Docker allows you to add multiple runtimes -to the Docker configuration. +## Configuring Docker -Add a new entry for the KVM platform entry to your Docker configuration -(`/etc/docker/daemon.json`) in order to provide the `--platform=kvm` runtime -argument. - -In the end, the file should look something like: +The platform is selected by the `--platform` command line flag passed to +`runsc`. By default, the ptrace platform is selected. For example, to select the +KVM platform, modify your Docker configuration (`/etc/docker/daemon.json`) to +pass the `--platform` argument: ```json { "runtimes": { "runsc": { - "path": "/usr/local/bin/runsc" - }, - "runsc-kvm": { "path": "/usr/local/bin/runsc", "runtimeArgs": [ "--platform=kvm" ] - } + } } } ``` @@ -104,13 +66,27 @@ this is done via `systemd`: sudo systemctl restart docker ``` -## Running a container +Note that you may configure multiple runtimes using different platforms. For +example, the following configuration has one configuration for ptrace and one +for the KVM platform: -Now run your container using the `runsc-kvm` runtime. This will run the -container using the KVM platform: - -```bash -docker run --runtime=runsc-kvm --rm hello-world +```json +{ + "runtimes": { + "runsc-ptrace": { + "path": "/usr/local/bin/runsc", + "runtimeArgs": [ + "--platform=ptrace" + ] + }, + "runsc-kvm": { + "path": "/usr/local/bin/runsc", + "runtimeArgs": [ + "--platform=kvm" + ] + } + } +} ``` [nested-azure]: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/nested-virtualization diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go index 92d062513..95dfd1e90 100644 --- a/pkg/sentry/arch/syscalls_arm64.go +++ b/pkg/sentry/arch/syscalls_arm64.go @@ -23,7 +23,7 @@ const restartSyscallNr = uintptr(128) // // In linux, at the entry of the syscall handler(el0_svc_common()), value of R0 // is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0 -// was not accessible to the user space application, so we have to do the same +// was not accessible to the userspace application, so we have to do the same // operation in the sentry code to save the R0 value into the App context. func (c *context64) SyscallSaveOrig() { c.OrigR0 = c.Regs.Regs[0] diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 5a1d4fd57..aacf28da2 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -144,7 +144,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume if v > math.MaxInt32 { v = math.MaxInt32 // Silently truncate. } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index c9db78e06..a5903b0b5 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -199,10 +199,10 @@ func (t *Task) doSyscall() taskRunState { // // On x86, register rax was shared by syscall number and return // value, and at the entry of the syscall handler, the rax was - // saved to regs.orig_rax which was exposed to user space. + // saved to regs.orig_rax which was exposed to userspace. // But on arm64, syscall number was passed through X8, and the X0 // was shared by the first syscall argument and return value. The - // X0 was saved to regs.orig_x0 which was not exposed to user space. + // X0 was saved to regs.orig_x0 which was not exposed to userspace. // So we have to do the same operation here to save the X0 value // into the task context. t.Arch().SyscallSaveOrig() diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9dea2b5ff..60df51dae 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2718,7 +2718,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) @@ -2787,7 +2787,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) @@ -2803,7 +2803,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b39ffa9fb..0ab4c3e19 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -235,11 +235,11 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to user space since + // CopiedBytes is the number of bytes copied to userspace since // this measure began. CopiedBytes int - // PrevCopiedBytes is the number of bytes copied to user space in + // PrevCopiedBytes is the number of bytes copied to userspace in // the previous RTT period. PrevCopiedBytes int diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 71735029e..b5ba972f1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1097,7 +1097,7 @@ func (e *endpoint) initialReceiveWindow() int { } // ModerateRecvBuf adjusts the receive buffer and the advertised window -// based on the number of bytes copied to user space. +// based on the number of bytes copied to userspace. func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0b4512c65..6ef32a1b3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5869,7 +5869,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Invoke the moderation API. This is required for auto-tuning // to happen. This method is normally expected to be invoked // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. + // copying to userspace by invoking it explicitly here. c.EP.ModerateRecvBuf(totalCopied) // Now send a keep-alive packet to trigger an ACK so that we can diff --git a/runsc/cmd/help.go b/runsc/cmd/help.go index c7d210140..cd85dabbb 100644 --- a/runsc/cmd/help.go +++ b/runsc/cmd/help.go @@ -65,16 +65,10 @@ func (h *Help) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{} switch f.NArg() { case 0: fmt.Fprintf(h.cdr.Output, "Usage: %s \n\n", h.cdr.Name()) - fmt.Fprintf(h.cdr.Output, `runsc is a command line client for running applications packaged in the Open -Container Initiative (OCI) format. Applications run by runsc are run in an -isolated gVisor sandbox that emulates a Linux environment. + fmt.Fprintf(h.cdr.Output, `runsc is the gVisor container runtime. -gVisor is a user-space kernel, written in Go, that implements a substantial -portion of the Linux system call interface. It provides an additional layer -of isolation between running applications and the host operating system. - -Functionality is provided by subcommands. For additonal help on individual -subcommands use "%s %s ". +Functionality is provided by subcommands. For help with a specific subcommand, +use "%s %s ". `, h.cdr.Name(), h.Name()) h.cdr.VisitGroups(func(g *subcommands.CommandGroup) { diff --git a/website/BUILD b/website/BUILD index d6afd5f44..c97b2560b 100644 --- a/website/BUILD +++ b/website/BUILD @@ -138,7 +138,6 @@ docs( "//g3doc:community", "//g3doc:index", "//g3doc:roadmap", - "//g3doc/architecture_guide:index", "//g3doc/architecture_guide:performance", "//g3doc/architecture_guide:platforms", "//g3doc/architecture_guide:resources", diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html index 33ea8e1de..549305089 100644 --- a/website/_layouts/docs.html +++ b/website/_layouts/docs.html @@ -51,7 +51,9 @@ categories: Create issue

{% endif %} +
{{ content }} +
diff --git a/website/_sass/front.scss b/website/_sass/front.scss index 44a7e3473..0e4208f3c 100644 --- a/website/_sass/front.scss +++ b/website/_sass/front.scss @@ -4,12 +4,14 @@ background-repeat: no-repeat; background-size: cover; background-blend-mode: darken; - background-color: rgba(0, 0, 0, 0.1); + background-color: rgba(0, 0, 0, 0.3); p { color: #fff; margin-top: 0; margin-bottom: 0; font-weight: 300; + font-size: 24px; + line-height: 30px; } } diff --git a/website/_sass/style.scss b/website/_sass/style.scss index 520ea469a..4deb945d4 100644 --- a/website/_sass/style.scss +++ b/website/_sass/style.scss @@ -142,3 +142,13 @@ table th { margin-top: 10px; margin-bottom: 20px; } + +.docs-content * img { + display: block; + margin: 20px auto; +} + +.blog-content * img { + display: block; + margin: 20px auto; +} diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md index ed6d97ffe..fbdd511dd 100644 --- a/website/blog/2019-11-18-security-basics.md +++ b/website/blog/2019-11-18-security-basics.md @@ -56,15 +56,9 @@ in combination: redundant walls, scattered draw bridges, small bottle-neck entrances, moats, etc. A simplified version of the design is below -([more detailed version](/docs/architecture_guide/))[^2]: +([more detailed version](/docs/))[^2]: --------------------------------------------------------------------------------- - -![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png) - -Figure 1: Simplified design of gVisor. - --------------------------------------------------------------------------------- +![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png "Simplified design of gVisor.") In order to discuss design principles, the following components are important to know: @@ -134,13 +128,7 @@ minimum level of permission is required for it to perform its function. Specifically, the closer you are to the untrusted application, the less privilege you have. --------------------------------------------------------------------------------- - -![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png) - -Figure 2: runsc components and their privileges. - --------------------------------------------------------------------------------- +![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png "runsc components and their privileges.") This is evident in how runsc (the drop in gVisor binary for Docker/Kubernetes) constructs the sandbox. The Sentry has the least privilege possible (it can't @@ -222,15 +210,7 @@ the host Linux syscalls. In other words, with gVisor, applications get the vast majority (and growing) functionality of Linux containers for only 68 possible syscalls to the Host OS. 350 syscalls to 68 is attack surface reduction. --------------------------------------------------------------------------------- - -![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png) - -Figure 3: Reduction of Attack Surface of the Syscall Table. Note that the -Senty's Syscall Emulation Layer keeps the Containerized Process from ever -calling the Host OS. - --------------------------------------------------------------------------------- +![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png "Reduction of Attack Surface of the Syscall Table. Note that the Senty's Syscall Emulation Layer keeps the Containerized Process from ever calling the Host OS.") ## Secure-by-default diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md index 78f0a6714..5a5e38fd7 100644 --- a/website/blog/2020-04-02-networking-security.md +++ b/website/blog/2020-04-02-networking-security.md @@ -69,13 +69,7 @@ a similar syscall). Moreover, because packets typically come from off-host (e.g. the internet), the Host OS's packet processing code has received a lot of scrutiny, hopefully resulting in a high degree of hardening. --------------------------------------------------------------------------------- - -![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png) - -Figure 1: Netstack and gVisor - --------------------------------------------------------------------------------- +![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png "Network and gVisor.") ## Writing a network stack diff --git a/website/index.md b/website/index.md index 95d5d16f0..84f877d49 100644 --- a/website/index.md +++ b/website/index.md @@ -3,10 +3,10 @@
-

gVisor is an application kernel and container runtime providing defense-in-depth for containers anywhere.

+

gVisor is an application kernel for containers that provides efficient defense-in-depth anywhere.

+ Quick start  Learn More  - GitHub 

@@ -19,8 +19,8 @@

Container-native Security

-

By providing each container with its own userspace kernel, gVisor limits - the attack surface of the host. This protection does not limit +

By providing each container with its own application kernel, gVisor + limits the attack surface of the host. This protection does not limit functionality: gVisor runs unmodified binaries and integrates with container orchestration systems, such as Docker and Kubernetes, and supports features such as volumes and sidecars.

@@ -43,7 +43,7 @@ The pluggable platform architecture of gVisor allows it to run anywhere, enabling consistent security policies across multiple environments without having to rearchitect your infrastructure.

- Get Started » + Read More »
-- cgit v1.2.3 From 41da7a568b1e4f46b3bc09724996556fb18b4d16 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 4 Jun 2020 15:38:33 -0700 Subject: Fix copylocks error about copying IPTables. IPTables.connections contains a sync.RWMutex. Copying it will trigger copylocks analysis. Tested by manually enabling nogo tests. sync.RWMutex is added to IPTables for the additional race condition discovered. PiperOrigin-RevId: 314817019 --- pkg/sentry/socket/netfilter/netfilter.go | 36 ++++++++++++--------------- pkg/sentry/socket/netstack/stack.go | 9 +++---- pkg/tcpip/stack/iptables.go | 42 +++++++++++++++++++++++++++----- pkg/tcpip/stack/iptables_types.go | 15 ++++++++---- pkg/tcpip/stack/stack.go | 23 ++++------------- pkg/tcpip/transport/icmp/endpoint.go | 5 ---- pkg/tcpip/transport/packet/endpoint.go | 5 ---- pkg/tcpip/transport/raw/endpoint.go | 5 ---- pkg/tcpip/transport/tcp/endpoint.go | 5 ---- pkg/tcpip/transport/udp/endpoint.go | 5 ---- runsc/boot/loader.go | 2 +- 11 files changed, 71 insertions(+), 81 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 47ff48c00..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index d989dbe91..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -43,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -106,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -158,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index af72b9c46..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -16,6 +16,7 @@ package stack import ( "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8af06cb9a..294ce8775 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 29ff68df3..3bc72bc19 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index bab2d63ae..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 25a17940d..21c34fac2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d048ef90c..edca98160 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1172,11 +1172,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 79faa7869..663af8fec 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -247,11 +247,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index f802bc9fb..002479612 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1056,7 +1056,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) } - s.FillDefaultIPTables() + s.FillIPTablesMetadata() return &s, nil } -- cgit v1.2.3 From 526df4f52a07a02687dd43ceb752621a41883f95 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 5 Jun 2020 13:41:19 -0700 Subject: Fix error code returned due to Port exhaustion. For TCP sockets gVisor incorrectly returns EAGAIN when no ephemeral ports are available to bind during a connect. Linux returns EADDRNOTAVAIL. This change fixes gVisor to return the correct code and adds a test for the same. This change also fixes a minor bug for ping sockets where connect() would fail with EINVAL unless the socket was bound first. Also added tests for testing UDP Port exhaustion and Ping socket port exhaustion. PiperOrigin-RevId: 314988525 --- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 9 ++ pkg/tcpip/transport/icmp/endpoint.go | 1 + test/syscalls/BUILD | 16 ++ test/syscalls/linux/BUILD | 35 +++++ test/syscalls/linux/ping_socket.cc | 91 +++++++++++ test/syscalls/linux/poll.cc | 6 +- .../linux/socket_inet_loopback_nogotsan.cc | 171 +++++++++++++++++++++ test/syscalls/linux/socket_ipv4_udp_unbound.cc | 33 ++++ 9 files changed, 360 insertions(+), 3 deletions(-) create mode 100644 test/syscalls/linux/ping_socket.cc create mode 100644 test/syscalls/linux/socket_inet_loopback_nogotsan.cc (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 333e0042e..8f0f5466e 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -50,5 +50,6 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 60df51dae..e1e0c5931 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -33,6 +33,7 @@ import ( "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" @@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3bc72bc19..57e0a069b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -506,6 +506,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC localPort := uint16(0) switch e.state { + case stateInitial: case stateBound, stateConnected: localPort = e.ID.LocalPort if e.BindNICID == 0 { diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 3406a2de8..d68afbe44 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -400,6 +400,14 @@ syscall_test( vfs2 = "True", ) +syscall_test( + size = "medium", + # Takes too long under gotsan to run. + tags = ["nogotsan"], + test = "//test/syscalls/linux:ping_socket_test", + vfs2 = "True", +) + syscall_test( size = "large", add_overlay = True, @@ -697,6 +705,14 @@ syscall_test( test = "//test/syscalls/linux:socket_inet_loopback_test", ) +syscall_test( + size = "large", + shard_count = 50, + # Takes too long for TSAN. Creates a lot of TCP sockets. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", +) + syscall_test( size = "large", shard_count = 50, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index f4b5de18d..ae2aa44dc 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1411,6 +1411,21 @@ cc_binary( ], ) +cc_binary( + name = "ping_socket_test", + testonly = 1, + srcs = ["ping_socket.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + gtest, + "//test/util:save_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_binary( name = "pipe_test", testonly = 1, @@ -2780,6 +2795,26 @@ cc_binary( ], ) +cc_binary( + name = "socket_inet_loopback_nogotsan_test", + testonly = 1, + srcs = ["socket_inet_loopback_nogotsan.cc"], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + "//test/util:file_descriptor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + gtest, + "//test/util:posix_error", + "//test/util:save_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + cc_binary( name = "socket_netlink_test", testonly = 1, diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc new file mode 100644 index 000000000..a9bfdb37b --- /dev/null +++ b/test/syscalls/linux/ping_socket.cc @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/save_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +class PingSocket : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The loopback address. + struct sockaddr_in addr_; +}; + +void PingSocket::SetUp() { + // On some hosts ping sockets are restricted to specific groups using the + // sysctl "ping_group_range". + int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP); + if (s < 0 && errno == EPERM) { + GTEST_SKIP(); + } + close(s); + + addr_ = {}; + // Just a random port as the destination port number is irrelevant for ping + // sockets. + addr_.sin_port = 12345; + addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr_.sin_family = AF_INET; +} + +void PingSocket::TearDown() {} + +// Test ICMP port exhaustion returns EAGAIN. +// +// We disable both random/cooperative S/R for this test as it makes way too many +// syscalls. +TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) { + DisableSave ds; + std::vector sockets; + constexpr int kSockets = 65536; + addr_.sin_port = 0; + for (int i = 0; i < kSockets; i++) { + auto s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); + int ret = connect(s.get(), reinterpret_cast(&addr_), + sizeof(addr_)); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc index 1e35a4a8b..7a316427d 100644 --- a/test/syscalls/linux/poll.cc +++ b/test/syscalls/linux/poll.cc @@ -259,9 +259,9 @@ TEST_F(PollTest, Nfds) { TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0); // gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE. - constexpr rlim_t gVisorMax = 1048576; - if (rlim.rlim_cur > gVisorMax) { - rlim.rlim_cur = gVisorMax; + constexpr rlim_t maxFD = 4096; + if (rlim.rlim_cur > maxFD) { + rlim.rlim_cur = maxFD; TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0); } diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc new file mode 100644 index 000000000..2324c7f6a --- /dev/null +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -0,0 +1,171 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/save_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +using ::testing::Gt; + +PosixErrorOr AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast( + reinterpret_cast(&addr)->sin_port); + case AF_INET6: + return static_cast( + reinterpret_cast(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +struct TestParam { + TestAddress listener; + TestAddress connector; +}; + +std::string DescribeTestParam(::testing::TestParamInfo const& info) { + return absl::StrCat("Listen", info.param.listener.description, "_Connect", + info.param.connector.description); +} + +using SocketInetLoopbackTest = ::testing::TestWithParam; + +// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral +// ports are already in use for a given destination ip/port. +// We disable S/R because this test creates a large number of sockets. +TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 10; + constexpr int kClients = 65536; + + // Create the listening socket. + auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + + // Now we keep opening connections till we run out of local ephemeral ports. + // and assert the error we get back. + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector clients; + std::vector servers; + + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast(&conn_addr), + connector.addr_len); + if (ret == 0) { + clients.push_back(std::move(client)); + FileDescriptor server = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + servers.push_back(std::move(server)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + All, SocketInetLoopbackTest, + ::testing::Values( + // Listeners bound to IPv4 addresses refuse connections using IPv6 + // addresses. + TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()}, + TestParam{V4Any(), V4MappedAny()}, + TestParam{V4Any(), V4MappedLoopback()}, + TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()}, + TestParam{V4Loopback(), V4MappedLoopback()}, + TestParam{V4MappedAny(), V4Any()}, + TestParam{V4MappedAny(), V4Loopback()}, + TestParam{V4MappedAny(), V4MappedAny()}, + TestParam{V4MappedAny(), V4MappedLoopback()}, + TestParam{V4MappedLoopback(), V4Any()}, + TestParam{V4MappedLoopback(), V4Loopback()}, + TestParam{V4MappedLoopback(), V4MappedLoopback()}, + + // Listeners bound to IN6ADDR_ANY accept all connections. + TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()}, + TestParam{V6Any(), V4MappedAny()}, + TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()}, + TestParam{V6Any(), V6Loopback()}, + + // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 + // addresses. + TestParam{V6Loopback(), V6Any()}, + TestParam{V6Loopback(), V6Loopback()}), + DescribeTestParam); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index bc4b07a62..1294d9050 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2129,6 +2129,39 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(kMessageSize)); } +// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = connect(s->get(), reinterpret_cast(&addr.addr), + addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + // Test that socket will receive packet info control message. TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. -- cgit v1.2.3 From 67565078bbcdd8f797206d996605df8f6658d00a Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Tue, 9 Jun 2020 18:44:57 -0700 Subject: Implement flock(2) in VFS2 LockFD is the generic implementation that can be embedded in FileDescriptionImpl implementations. Unique lock ID is maintained in vfs.FileDescription and is created on demand. Updates #1480 PiperOrigin-RevId: 315604825 --- pkg/sentry/devices/memdev/full.go | 1 + pkg/sentry/devices/memdev/null.go | 1 + pkg/sentry/devices/memdev/random.go | 1 + pkg/sentry/devices/memdev/zero.go | 1 + pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/lock/lock.go | 41 +++----- pkg/sentry/fs/lock/lock_set_functions.go | 8 +- pkg/sentry/fs/lock/lock_test.go | 111 +++++++++++----------- pkg/sentry/fsimpl/devpts/BUILD | 1 + pkg/sentry/fsimpl/devpts/devpts.go | 5 +- pkg/sentry/fsimpl/devpts/master.go | 5 + pkg/sentry/fsimpl/devpts/slave.go | 5 + pkg/sentry/fsimpl/eventfd/eventfd.go | 1 + pkg/sentry/fsimpl/ext/BUILD | 1 + pkg/sentry/fsimpl/ext/file_description.go | 1 + pkg/sentry/fsimpl/ext/inode.go | 6 ++ pkg/sentry/fsimpl/ext/regular_file.go | 1 + pkg/sentry/fsimpl/gofer/BUILD | 2 + pkg/sentry/fsimpl/gofer/filesystem.go | 9 +- pkg/sentry/fsimpl/gofer/gofer.go | 23 +++++ pkg/sentry/fsimpl/gofer/special_file.go | 4 +- pkg/sentry/fsimpl/host/BUILD | 1 + pkg/sentry/fsimpl/host/host.go | 8 +- pkg/sentry/fsimpl/kernfs/BUILD | 2 + pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 10 +- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 9 +- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 5 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 13 ++- pkg/sentry/fsimpl/pipefs/BUILD | 1 + pkg/sentry/fsimpl/pipefs/pipefs.go | 6 +- pkg/sentry/fsimpl/proc/BUILD | 1 + pkg/sentry/fsimpl/proc/subtasks.go | 5 +- pkg/sentry/fsimpl/proc/task.go | 5 +- pkg/sentry/fsimpl/proc/task_fds.go | 7 +- pkg/sentry/fsimpl/proc/task_files.go | 10 +- pkg/sentry/fsimpl/proc/tasks.go | 5 +- pkg/sentry/fsimpl/signalfd/signalfd.go | 1 + pkg/sentry/fsimpl/sys/BUILD | 1 + pkg/sentry/fsimpl/sys/sys.go | 7 +- pkg/sentry/fsimpl/timerfd/timerfd.go | 1 + pkg/sentry/fsimpl/tmpfs/filesystem.go | 6 +- pkg/sentry/fsimpl/tmpfs/regular_file.go | 23 ----- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 41 +------- pkg/sentry/kernel/fd_table.go | 15 +-- pkg/sentry/kernel/kernel.go | 5 - pkg/sentry/kernel/pipe/BUILD | 1 + pkg/sentry/kernel/pipe/vfs.go | 13 ++- pkg/sentry/socket/hostinet/BUILD | 1 + pkg/sentry/socket/hostinet/socket_vfs2.go | 3 + pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/socket_vfs2.go | 8 +- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack_vfs2.go | 3 + pkg/sentry/socket/unix/BUILD | 1 + pkg/sentry/socket/unix/unix_vfs2.go | 7 +- pkg/sentry/syscalls/linux/sys_file.go | 39 ++------ pkg/sentry/syscalls/linux/vfs2/BUILD | 2 + pkg/sentry/syscalls/linux/vfs2/lock.go | 64 +++++++++++++ pkg/sentry/syscalls/linux/vfs2/vfs2.go | 2 +- pkg/sentry/vfs/BUILD | 1 + pkg/sentry/vfs/epoll.go | 1 + pkg/sentry/vfs/file_description.go | 25 ++++- pkg/sentry/vfs/file_description_impl_util.go | 80 ++++++++++++---- pkg/sentry/vfs/file_description_impl_util_test.go | 1 + pkg/sentry/vfs/inotify.go | 1 + test/syscalls/linux/BUILD | 3 + test/syscalls/linux/flock.cc | 75 ++++++++++++--- 67 files changed, 470 insertions(+), 281 deletions(-) create mode 100644 pkg/sentry/syscalls/linux/vfs2/lock.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go index c7e197691..af66fe4dc 100644 --- a/pkg/sentry/devices/memdev/full.go +++ b/pkg/sentry/devices/memdev/full.go @@ -42,6 +42,7 @@ type fullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD } // Release implements vfs.FileDescriptionImpl.Release. diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go index 33d060d02..92d3d71be 100644 --- a/pkg/sentry/devices/memdev/null.go +++ b/pkg/sentry/devices/memdev/null.go @@ -43,6 +43,7 @@ type nullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD } // Release implements vfs.FileDescriptionImpl.Release. diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go index acfa23149..6b81da5ef 100644 --- a/pkg/sentry/devices/memdev/random.go +++ b/pkg/sentry/devices/memdev/random.go @@ -48,6 +48,7 @@ type randomFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD // off is the "file offset". off is accessed using atomic memory // operations. diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index 3b1372b9e..c6f15054d 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -44,6 +44,7 @@ type zeroFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD } // Release implements vfs.FileDescriptionImpl.Release. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 2a278fbe3..ca41520b4 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -146,7 +146,7 @@ func (f *File) DecRef() { f.DecRefWithDestructor(func() { // Drop BSD style locks. lockRng := lock.LockRange{Start: 0, End: lock.LockEOF} - f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng) + f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng) // Release resources held by the FileOperations. f.FileOperations.Release() diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 926538d90..8a5d9c7eb 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -62,7 +62,7 @@ import ( type LockType int // UniqueID is a unique identifier of the holder of a regional file lock. -type UniqueID uint64 +type UniqueID interface{} const ( // ReadLock describes a POSIX regional file lock to be taken @@ -98,12 +98,7 @@ type Lock struct { // If len(Readers) > 0 then HasWriter must be false. Readers map[UniqueID]bool - // HasWriter indicates that this is a write lock held by a single - // UniqueID. - HasWriter bool - - // Writer is only valid if HasWriter is true. It identifies a - // single write lock holder. + // Writer holds the writer unique ID. It's nil if there are no writers. Writer UniqueID } @@ -186,7 +181,6 @@ func makeLock(uid UniqueID, t LockType) Lock { case ReadLock: value.Readers[uid] = true case WriteLock: - value.HasWriter = true value.Writer = uid default: panic(fmt.Sprintf("makeLock: invalid lock type %d", t)) @@ -196,10 +190,7 @@ func makeLock(uid UniqueID, t LockType) Lock { // isHeld returns true if uid is a holder of Lock. func (l Lock) isHeld(uid UniqueID) bool { - if l.HasWriter && l.Writer == uid { - return true - } - return l.Readers[uid] + return l.Writer == uid || l.Readers[uid] } // lock sets uid as a holder of a typed lock on Lock. @@ -214,20 +205,20 @@ func (l *Lock) lock(uid UniqueID, t LockType) { } // We cannot downgrade a write lock to a read lock unless the // uid is the same. - if l.HasWriter { + if l.Writer != nil { if l.Writer != uid { panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer)) } // Ensure that there is only one reader if upgrading. l.Readers = make(map[UniqueID]bool) // Ensure that there is no longer a writer. - l.HasWriter = false + l.Writer = nil } l.Readers[uid] = true return case WriteLock: // If we are already the writer, then this is a no-op. - if l.HasWriter && l.Writer == uid { + if l.Writer == uid { return } // We can only upgrade a read lock to a write lock if there @@ -243,7 +234,6 @@ func (l *Lock) lock(uid UniqueID, t LockType) { } // Ensure that there is only a writer. l.Readers = make(map[UniqueID]bool) - l.HasWriter = true l.Writer = uid default: panic(fmt.Sprintf("lock: invalid lock type %d", t)) @@ -277,9 +267,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { switch t { case ReadLock: return l.lockable(r, func(value Lock) bool { - // If there is no writer, there's no problem adding - // another reader. - if !value.HasWriter { + // If there is no writer, there's no problem adding another reader. + if value.Writer == nil { return true } // If there is a writer, then it must be the same uid @@ -289,10 +278,9 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { case WriteLock: return l.lockable(r, func(value Lock) bool { // If there are only readers. - if !value.HasWriter { - // Then this uid can only take a write lock if - // this is a private upgrade, meaning that the - // only reader is uid. + if value.Writer == nil { + // Then this uid can only take a write lock if this is a private + // upgrade, meaning that the only reader is uid. return len(value.Readers) == 1 && value.Readers[uid] } // If the uid is already a writer on this region, then @@ -304,7 +292,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { } } -// lock returns true if uid took a lock of type t on the entire range of LockRange. +// lock returns true if uid took a lock of type t on the entire range of +// LockRange. // // Preconditions: r.Start <= r.End (will panic otherwise). func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { @@ -339,7 +328,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { seg, _ = l.SplitUnchecked(seg, r.End) } - // Set the lock on the segment. This is guaranteed to + // Set the lock on the segment. This is guaranteed to // always be safe, given canLock above. value := seg.ValuePtr() value.lock(uid, t) @@ -386,7 +375,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) { value := seg.Value() var remove bool - if value.HasWriter && value.Writer == uid { + if value.Writer == uid { // If we are unlocking a writer, then since there can // only ever be one writer and no readers, then this // lock should always be removed from the set. diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go index 8a3ace0c1..50a16e662 100644 --- a/pkg/sentry/fs/lock/lock_set_functions.go +++ b/pkg/sentry/fs/lock/lock_set_functions.go @@ -44,14 +44,9 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) return Lock{}, false } } - if val1.HasWriter != val2.HasWriter { + if val1.Writer != val2.Writer { return Lock{}, false } - if val1.HasWriter { - if val1.Writer != val2.Writer { - return Lock{}, false - } - } return val1, true } @@ -62,7 +57,6 @@ func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) for k, v := range val.Readers { val0.Readers[k] = v } - val0.HasWriter = val.HasWriter val0.Writer = val.Writer return val, val0 diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go index ba002aeb7..fad90984b 100644 --- a/pkg/sentry/fs/lock/lock_test.go +++ b/pkg/sentry/fs/lock/lock_test.go @@ -42,9 +42,6 @@ func equals(e0, e1 []entry) bool { if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) { return false } - if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter { - return false - } if e0[i].Lock.Writer != e1[i].Lock.Writer { return false } @@ -105,7 +102,7 @@ func TestCanLock(t *testing.T) { LockRange: LockRange{2048, 3072}, }, { - Lock: Lock{HasWriter: true, Writer: 1}, + Lock: Lock{Writer: 1}, LockRange: LockRange{3072, 4096}, }, }) @@ -241,7 +238,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -254,7 +251,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -273,7 +270,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 4096}, }, { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -301,7 +298,7 @@ func TestSetLock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 4096}, }, { @@ -318,7 +315,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -550,7 +547,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{1024, 4096}, }, { @@ -594,7 +591,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{1024, 3072}, }, { @@ -633,7 +630,7 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 4096 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 1024}, }, { @@ -663,11 +660,11 @@ func TestSetLock(t *testing.T) { // 0 1024 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -675,28 +672,30 @@ func TestSetLock(t *testing.T) { } for _, test := range tests { - l := fill(test.before) + t.Run(test.name, func(t *testing.T) { + l := fill(test.before) - r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, test.lockType, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } + r := LockRange{Start: test.start, End: test.end} + success := l.lock(test.uid, test.lockType, r) + var got []entry + for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + got = append(got, entry{ + Lock: seg.Value(), + LockRange: seg.Range(), + }) + } - if success != test.success { - t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success) - continue - } + if success != test.success { + t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success) + return + } - if success { - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) + if success { + if !equals(got, test.after) { + t.Errorf("got set %+v, want %+v", got, test.after) + } } - } + }) } } @@ -782,7 +781,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -824,7 +823,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -837,7 +836,7 @@ func TestUnlock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -876,7 +875,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, LockEOF}, }, }, @@ -889,7 +888,7 @@ func TestUnlock(t *testing.T) { // 0 4096 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 4096}, }, }, @@ -906,7 +905,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{1024, 4096}, }, { @@ -974,7 +973,7 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 1024}, }, { @@ -991,7 +990,7 @@ func TestUnlock(t *testing.T) { // 0 8 4096 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 8}, }, { @@ -1008,7 +1007,7 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 before: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 1024}, }, { @@ -1025,7 +1024,7 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 8192 max uint64 after: []entry{ { - Lock: Lock{HasWriter: true, Writer: 0}, + Lock: Lock{Writer: 0}, LockRange: LockRange{0, 1024}, }, { @@ -1041,19 +1040,21 @@ func TestUnlock(t *testing.T) { } for _, test := range tests { - l := fill(test.before) + t.Run(test.name, func(t *testing.T) { + l := fill(test.before) - r := LockRange{Start: test.start, End: test.end} - l.unlock(test.uid, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) - } + r := LockRange{Start: test.start, End: test.end} + l.unlock(test.uid, r) + var got []entry + for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + got = append(got, entry{ + Lock: seg.Value(), + LockRange: seg.Range(), + }) + } + if !equals(got, test.after) { + t.Errorf("got set %+v, want %+v", got, test.after) + } + }) } } diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 585764223..cf440dce8 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index c03c65445..9b0e0cca2 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -116,6 +117,8 @@ type rootInode struct { kernfs.InodeNotSymlink kernfs.OrderedChildren + locks lock.FileLocks + // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -183,7 +186,7 @@ func (i *rootInode) masterClose(t *Terminal) { // Open implements kernfs.Inode.Open. func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 7a7ce5d81..1d22adbe3 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -34,6 +35,8 @@ type masterInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink + locks lock.FileLocks + // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -55,6 +58,7 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf inode: mi, t: t, } + fd.LockFD.Init(&mi.locks) if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { mi.DecRef() return nil, err @@ -85,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds type masterFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD inode *masterInode t *Terminal diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index 526cd406c..7fe475080 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -33,6 +34,8 @@ type slaveInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink + locks lock.FileLocks + // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -51,6 +54,7 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs fd := &slaveFileDescription{ inode: si, } + fd.LockFD.Init(&si.locks) if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { si.DecRef() return nil, err @@ -91,6 +95,7 @@ func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds type slaveFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD inode *slaveInode } diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index c573d7935..d12d78b84 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -37,6 +37,7 @@ type EventFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD // queue is used to notify interested parties when the event object // becomes readable or writable. diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index ff861d0fe..973fa0def 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -60,6 +60,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index 92f7da40d..90b086468 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -26,6 +26,7 @@ import ( type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 485f86f4b..e4b434b13 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -54,6 +55,8 @@ type inode struct { // diskInode gives us access to the inode struct on disk. Immutable. diskInode disklayout.Inode + locks lock.FileLocks + // This is immutable. The first field of the implementations must have inode // as the first field to ensure temporality. impl interface{} @@ -157,6 +160,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt switch in.impl.(type) { case *regularFile: var fd regularFileFD + fd.LockFD.Init(&in.locks) if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } @@ -168,6 +172,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt return nil, syserror.EISDIR } var fd directoryFD + fd.LockFD.Init(&in.locks) if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } @@ -178,6 +183,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt return nil, syserror.ELOOP } var fd symlinkFD + fd.LockFD.Init(&in.locks) fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil default: diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 30135ddb0..f7015c44f 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -77,6 +77,7 @@ func (in *inode) isRegular() bool { // vfs.FileDescriptionImpl. type regularFileFD struct { fileDescription + vfs.LockFD // off is the file offset. off is accessed using atomic memory operations. off int64 diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index f5f35a3bc..5cdeeaeb5 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -54,6 +54,7 @@ go_library( "//pkg/p9", "//pkg/safemem", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/host", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", @@ -68,6 +69,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 36e0e1856..40933b74b 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -801,6 +801,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, err } fd := ®ularFileFD{} + fd.LockFD.Init(&d.locks) if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ AllowDirectIO: true, }); err != nil { @@ -826,6 +827,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf } } fd := &directoryFD{} + fd.LockFD.Init(&d.locks) if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } @@ -842,7 +844,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf } case linux.S_IFIFO: if d.isSynthetic() { - return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags) + return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks) } } return d.openSpecialFileLocked(ctx, mnt, opts) @@ -902,7 +904,7 @@ retry: return nil, err } } - fd, err := newSpecialFileFD(h, mnt, d, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags) if err != nil { h.close(ctx) return nil, err @@ -989,6 +991,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving var childVFSFD *vfs.FileDescription if useRegularFileFD { fd := ®ularFileFD{} + fd.LockFD.Init(&child.locks) if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ AllowDirectIO: true, }); err != nil { @@ -1003,7 +1006,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving if fdobj != nil { h.fd = int32(fdobj.Release()) } - fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags) if err != nil { h.close(ctx) return nil, err diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 3f3bd56f0..0d88a328e 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -45,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -52,6 +53,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -662,6 +664,8 @@ type dentry struct { // If this dentry represents a synthetic named pipe, pipe is the pipe // endpoint bound to this file. pipe *pipe.VFSPipe + + locks lock.FileLocks } // dentryAttrMask returns a p9.AttrMask enabling all attributes used by the @@ -1366,6 +1370,9 @@ func (d *dentry) decLinks() { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD + + lockLogging sync.Once } func (fd *fileDescription) filesystem() *filesystem { @@ -1416,3 +1423,19 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name) } + +// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { + fd.lockLogging.Do(func() { + log.Infof("File lock using gofer file handled internally.") + }) + return fd.LockFD.LockBSD(ctx, uid, t, block) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { + fd.lockLogging.Do(func() { + log.Infof("Range lock using gofer file handled internally.") + }) + return fd.LockFD.LockPOSIX(ctx, uid, t, rng, block) +} diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index ff6126b87..289efdd25 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -51,7 +52,7 @@ type specialFileFD struct { off int64 } -func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK @@ -60,6 +61,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci seekable: seekable, mayBlock: mayBlock, } + fd.LockFD.Init(locks) if mayBlock && h.fd >= 0 { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index ca0fe6d2b..54f16ad63 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -39,6 +39,7 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 18b127521..5ec5100b8 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -182,6 +183,8 @@ type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink + locks lock.FileLocks + // When the reference count reaches zero, the host fd is closed. refs.AtomicRefCount @@ -468,7 +471,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u return nil, err } // Currently, we only allow Unix sockets to be imported. - return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d) + return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks) } // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that @@ -478,6 +481,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u fileDescription: fileDescription{inode: i}, termios: linux.DefaultSlaveTermios, } + fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { return nil, err @@ -486,6 +490,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u } fd := &fileDescription{inode: i} + fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { return nil, err @@ -497,6 +502,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but // cached to reduce indirections and casting. fileDescription does not hold diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index ef34cb28a..0299dbde9 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -49,6 +49,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", @@ -67,6 +68,7 @@ go_test( "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 1568a9d49..6418de0a3 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -38,7 +39,8 @@ type DynamicBytesFile struct { InodeNotDirectory InodeNotSymlink - data vfs.DynamicBytesSource + locks lock.FileLocks + data vfs.DynamicBytesSource } var _ Inode = (*DynamicBytesFile)(nil) @@ -55,7 +57,7 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint // Open implements Inode.Open. func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &DynamicBytesFD{} - if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil { + if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil { return nil, err } return &fd.vfsfd, nil @@ -77,13 +79,15 @@ func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credent type DynamicBytesFD struct { vfs.FileDescriptionDefaultImpl vfs.DynamicBytesFileDescriptionImpl + vfs.LockFD vfsfd vfs.FileDescription inode Inode } // Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) error { +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *lock.FileLocks, flags uint32) error { + fd.LockFD.Init(locks) if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 8284e76a7..33a5968ca 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -42,6 +43,7 @@ import ( type GenericDirectoryFD struct { vfs.FileDescriptionDefaultImpl vfs.DirectoryFileDescriptionDefaultImpl + vfs.LockFD vfsfd vfs.FileDescription children *OrderedChildren @@ -55,9 +57,9 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} - if err := fd.Init(children, opts); err != nil { + if err := fd.Init(children, locks, opts); err != nil { return nil, err } if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { @@ -69,11 +71,12 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre // Init initializes a GenericDirectoryFD. Use it when overriding // GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the // correct implementation. -func (fd *GenericDirectoryFD) Init(children *OrderedChildren, opts *vfs.OpenOptions) error { +func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. return syserror.EISDIR } + fd.LockFD.Init(locks) fd.children = children return nil } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 982daa2e6..0e4927215 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -555,6 +556,8 @@ type StaticDirectory struct { InodeAttrs InodeNoDynamicLookup OrderedChildren + + locks lock.FileLocks } var _ Inode = (*StaticDirectory)(nil) @@ -584,7 +587,7 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3 // Open implements kernfs.Inode. func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &opts) + fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 412cf6ac9..6749facf7 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -100,8 +101,10 @@ type readonlyDir struct { kernfs.InodeNotSymlink kernfs.InodeNoDynamicLookup kernfs.InodeDirectoryNoNewChildren - kernfs.OrderedChildren + + locks lock.FileLocks + dentry kernfs.Dentry } @@ -117,7 +120,7 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod } func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) if err != nil { return nil, err } @@ -128,10 +131,12 @@ type dir struct { attrs kernfs.InodeNotSymlink kernfs.InodeNoDynamicLookup + kernfs.OrderedChildren + + locks lock.FileLocks fs *filesystem dentry kernfs.Dentry - kernfs.OrderedChildren } func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { @@ -147,7 +152,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte } func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD index 5950a2d59..c618dbe6c 100644 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index cab771211..e4dabaa33 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -81,7 +82,8 @@ type inode struct { kernfs.InodeNotSymlink kernfs.InodeNoopRefCount - pipe *pipe.VFSPipe + locks lock.FileLocks + pipe *pipe.VFSPipe ino uint64 uid auth.KUID @@ -147,7 +149,7 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth. // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags) + return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks) } // NewConnectedPipeFDs returns a pair of FileDescriptions representing the read diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 17c1342b5..351ba4ee9 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -35,6 +35,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 36a911db4..e2cdb7ee9 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -37,6 +38,8 @@ type subtasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid + locks lock.FileLocks + fs *filesystem task *kernel.Task pidns *kernel.PIDNamespace @@ -153,7 +156,7 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro // Open implements kernfs.Inode. func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &subtasksFD{task: i.task} - if err := fd.Init(&i.OrderedChildren, &opts); err != nil { + if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil { return nil, err } if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 482055db1..44078a765 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -38,6 +39,8 @@ type taskInode struct { kernfs.InodeAttrs kernfs.OrderedChildren + locks lock.FileLocks + task *kernel.Task } @@ -103,7 +106,7 @@ func (i *taskInode) Valid(ctx context.Context) bool { // Open implements kernfs.Inode. func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 44ccc9e4a..ef6c1d04f 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -53,6 +54,8 @@ func taskFDExists(t *kernel.Task, fd int32) bool { } type fdDir struct { + locks lock.FileLocks + fs *filesystem task *kernel.Task @@ -143,7 +146,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro // Open implements kernfs.Inode. func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) if err != nil { return nil, err } @@ -270,7 +273,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, // Open implements kernfs.Inode. func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 2f297e48a..e5eaa91cd 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -775,6 +776,8 @@ type namespaceInode struct { kernfs.InodeNoopRefCount kernfs.InodeNotDirectory kernfs.InodeNotSymlink + + locks lock.FileLocks } var _ kernfs.Inode = (*namespaceInode)(nil) @@ -791,6 +794,7 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32 func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &namespaceFD{inode: i} i.IncRef() + fd.LockFD.Init(&i.locks) if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } @@ -801,6 +805,7 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd * // /proc/[pid]/ns/*. type namespaceFD struct { vfs.FileDescriptionDefaultImpl + vfs.LockFD vfsfd vfs.FileDescription inode *namespaceInode @@ -825,8 +830,3 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release() { fd.inode.DecRef() } - -// OnClose implements FileDescriptionImpl. -func (*namespaceFD) OnClose(context.Context) error { - return nil -} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index b51d43954..58c8b9d05 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -43,6 +44,8 @@ type tasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid + locks lock.FileLocks + fs *filesystem pidns *kernel.PIDNamespace @@ -197,7 +200,7 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback // Open implements kernfs.Inode. func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index d29ef3f83..242ba9b5d 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -31,6 +31,7 @@ type SignalFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD // target is the original signal target task. // diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index a741e2bb6..237f17def 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 0af373604..b84463d3a 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -98,8 +99,10 @@ type dir struct { kernfs.InodeNoDynamicLookup kernfs.InodeNotSymlink kernfs.InodeDirectoryNoNewChildren - kernfs.OrderedChildren + + locks lock.FileLocks + dentry kernfs.Dentry } @@ -121,7 +124,7 @@ func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.Set // Open implements kernfs.Inode.Open. func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 60c92d626..2dc90d484 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -32,6 +32,7 @@ type TimerFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD events waiter.Queue timer *ktime.Timer diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e801680e8..72399b321 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -399,6 +399,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open switch impl := d.inode.impl.(type) { case *regularFile: var fd regularFileFD + fd.LockFD.Init(&d.inode.locks) if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } @@ -414,15 +415,16 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, syserror.EISDIR } var fd directoryFD + fd.LockFD.Init(&d.inode.locks) if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil case *symlink: - // Can't open symlinks without O_PATH (which is unimplemented). + // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH. return nil, syserror.ELOOP case *namedPipe: - return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags) + return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks) case *deviceFile: return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts) case *socketFile: diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 4f2ae04d2..77447b32c 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -366,28 +365,6 @@ func (fd *regularFileFD) Sync(ctx context.Context) error { return nil } -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *regularFileFD) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error { - return fd.inode().lockBSD(uid, t, block) -} - -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. -func (fd *regularFileFD) UnlockBSD(ctx context.Context, uid lock.UniqueID) error { - fd.inode().unlockBSD(uid) - return nil -} - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error { - return fd.inode().lockPOSIX(uid, t, rng, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error { - fd.inode().unlockPOSIX(uid, rng) - return nil -} - // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { file := fd.inode().impl.(*regularFile) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 7ce1b86c7..71a7522af 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -311,7 +310,6 @@ type inode struct { ctime int64 // nanoseconds mtime int64 // nanoseconds - // Advisory file locks, which lock at the inode level. locks lock.FileLocks // Inotify watches for this inode. @@ -539,44 +537,6 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return nil } -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) lockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - switch i.impl.(type) { - case *regularFile: - return i.locks.LockBSD(uid, t, block) - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) unlockBSD(uid fslock.UniqueID) error { - switch i.impl.(type) { - case *regularFile: - i.locks.UnlockBSD(uid) - return nil - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) lockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - switch i.impl.(type) { - case *regularFile: - return i.locks.LockPOSIX(uid, t, rng, block) - } - return syserror.EBADF -} - -// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular. -func (i *inode) unlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) error { - switch i.impl.(type) { - case *regularFile: - i.locks.UnlockPOSIX(uid, rng) - return nil - } - return syserror.EBADF -} - // allocatedBlocksForSize returns the number of 512B blocks needed to // accommodate the given size in bytes, as appropriate for struct // stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block @@ -708,6 +668,7 @@ func (i *inode) userXattrSupported() bool { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index dbfcef0fa..b35afafe3 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -80,9 +80,6 @@ type FDTable struct { refs.AtomicRefCount k *Kernel - // uid is a unique identifier. - uid uint64 - // mu protects below. mu sync.Mutex `state:"nosave"` @@ -130,7 +127,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { // drop drops the table reference. func (f *FDTable) drop(file *fs.File) { // Release locks. - file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lock.UniqueID(f.uid), lock.LockRange{0, lock.LockEOF}) + file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF}) // Send inotify events. d := file.Dirent @@ -164,17 +161,9 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) { file.DecRef() } -// ID returns a unique identifier for this FDTable. -func (f *FDTable) ID() uint64 { - return f.uid -} - // NewFDTable allocates a new FDTable that may be used by tasks in k. func (k *Kernel) NewFDTable() *FDTable { - f := &FDTable{ - k: k, - uid: atomic.AddUint64(&k.fdMapUids, 1), - } + f := &FDTable{k: k} f.init() return f } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 5efeb3767..bcbeb6a39 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -194,11 +194,6 @@ type Kernel struct { // cpuClockTickerSetting is protected by runningTasksMu. cpuClockTickerSetting ktime.Setting - // fdMapUids is an ever-increasing counter for generating FDTable uids. - // - // fdMapUids is mutable, and is accessed using atomic memory operations. - fdMapUids uint64 - // uniqueID is used to generate unique identifiers. // // uniqueID is mutable, and is accessed using atomic memory operations. diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 7bfa9075a..0db546b98 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 2602bed72..c0e9ee1f4 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -61,11 +62,13 @@ func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { // // Preconditions: statusFlags should not contain an open access mode. func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { - return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags) + // Connected pipes share the same locks. + locks := &lock.FileLocks{} + return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } // Open opens the pipe represented by vp. -func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) { +func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() defer vp.mu.Unlock() @@ -75,7 +78,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s return nil, syserror.EINVAL } - fd := vp.newFD(mnt, vfsd, statusFlags) + fd := vp.newFD(mnt, vfsd, statusFlags, locks) // Named pipes have special blocking semantics during open: // @@ -127,10 +130,11 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) *vfs.FileDescription { fd := &VFSPipeFD{ pipe: &vp.pipe, } + fd.LockFD.Init(locks) fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, @@ -159,6 +163,7 @@ type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD pipe *Pipe } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index e82d6cd1e..60c9896fc 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 677743113..027add1fd 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -35,6 +36,7 @@ import ( type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl + vfs.LockFD // We store metadata for hostinet sockets internally. Technically, we should // access metadata (e.g. through stat, chmod) on the host for correctness, @@ -59,6 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in fd: fd, }, } + s.LockFD.Init(&lock.FileLocks{}) if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 7212d8644..420e573c9 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index b854bf990..8bfee5193 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -40,6 +41,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -66,7 +68,7 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV return nil, err } - return &SocketVFS2{ + fd := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, @@ -75,7 +77,9 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV connection: connection, sendBufferSize: defaultSendBufferSize, }, - }, nil + } + fd.LockFD.Init(&lock.FileLocks{}) + return fd, nil } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 8f0f5466e..0f592ecc3 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index fcd8013c0..1412a4810 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,6 +39,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -64,6 +66,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu protocol: protocol, }, } + s.LockFD.Init(&lock.FileLocks{}) vfsfd := &s.vfsfd if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index de2cc4bdf..7d4cc80fe 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", + "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 45e109361..8c32371a2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -39,6 +40,7 @@ type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD socketOpsCommon } @@ -51,7 +53,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) - fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d) + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &lock.FileLocks{}) if err != nil { return nil, syserr.FromError(err) } @@ -60,7 +62,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) // NewFileDescription creates and returns a socket file description // corresponding to the given mount and dentry. -func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry) (*vfs.FileDescription, error) { +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *lock.FileLocks) (*vfs.FileDescription, error) { // You can create AF_UNIX, SOCK_RAW sockets. They're the same as // SOCK_DGRAM and don't require CAP_NET_RAW. if stype == linux.SOCK_RAW { @@ -73,6 +75,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 35a98212a..8347617bd 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -998,9 +998,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } - // The lock uid is that of the Task's FDTable. - lockUniqueID := lock.UniqueID(t.FDTable().ID()) - // These locks don't block; execute the non-blocking operation using the inode's lock // context directly. switch flock.Type { @@ -1010,12 +1007,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } @@ -1026,18 +1023,18 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } return 0, nil, nil case linux.F_UNLCK: - file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lockUniqueID, rng) + file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng) return 0, nil, nil default: return 0, nil, syserror.EINVAL @@ -2157,22 +2154,6 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall nonblocking := operation&linux.LOCK_NB != 0 operation &^= linux.LOCK_NB - // flock(2): - // Locks created by flock() are associated with an open file table entry. This means that - // duplicate file descriptors (created by, for example, fork(2) or dup(2)) refer to the - // same lock, and this lock may be modified or released using any of these descriptors. Furthermore, - // the lock is released either by an explicit LOCK_UN operation on any of these duplicate - // descriptors, or when all such descriptors have been closed. - // - // If a process uses open(2) (or similar) to obtain more than one descriptor for the same file, - // these descriptors are treated independently by flock(). An attempt to lock the file using - // one of these file descriptors may be denied by a lock that the calling process has already placed via - // another descriptor. - // - // We use the File UniqueID as the lock UniqueID because it needs to reference the same lock across dup(2) - // and fork(2). - lockUniqueID := lock.UniqueID(file.UniqueID) - // A BSD style lock spans the entire file. rng := lock.LockRange{ Start: 0, @@ -2183,29 +2164,29 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.LOCK_EX: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } case linux.LOCK_SH: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } case linux.LOCK_UN: - file.Dirent.Inode.LockCtx.BSD.UnlockRegion(lockUniqueID, rng) + file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng) default: // flock(2): EINVAL operation is invalid. return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index c0d005247..9f93f4354 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -14,6 +14,7 @@ go_library( "getdents.go", "inotify.go", "ioctl.go", + "lock.go", "memfd.go", "mmap.go", "mount.go", @@ -42,6 +43,7 @@ go_library( "//pkg/fspath", "//pkg/gohacks", "//pkg/sentry/arch", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/eventfd", "//pkg/sentry/fsimpl/pipefs", diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go new file mode 100644 index 000000000..bf19028c4 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -0,0 +1,64 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Flock implements linux syscall flock(2). +func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + operation := args[1].Int() + + file := t.GetFileVFS2(fd) + if file == nil { + // flock(2): EBADF fd is not an open file descriptor. + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + nonblocking := operation&linux.LOCK_NB != 0 + operation &^= linux.LOCK_NB + + var blocker lock.Blocker + if !nonblocking { + blocker = t + } + + switch operation { + case linux.LOCK_EX: + if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil { + return 0, nil, err + } + case linux.LOCK_SH: + if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil { + return 0, nil, err + } + case linux.LOCK_UN: + if err := file.UnlockBSD(t); err != nil { + return 0, nil, err + } + default: + // flock(2): EINVAL operation is invalid. + return 0, nil, syserror.EINVAL + } + + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 7b6e7571a..954c82f97 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -62,7 +62,7 @@ func Override() { s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt) s.Table[59] = syscalls.Supported("execve", Execve) s.Table[72] = syscalls.Supported("fcntl", Fcntl) - delete(s.Table, 73) // flock + s.Table[73] = syscalls.Supported("fcntl", Flock) s.Table[74] = syscalls.Supported("fsync", Fsync) s.Table[75] = syscalls.Supported("fdatasync", Fdatasync) s.Table[76] = syscalls.Supported("truncate", Truncate) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 774cc66cc..16d9f3a28 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -72,6 +72,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", + "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8297f964b..599c3131c 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -31,6 +31,7 @@ type EpollInstance struct { vfsfd FileDescription FileDescriptionDefaultImpl DentryMetadataFileDescriptionImpl + NoLockFD // q holds waiters on this EpollInstance. q waiter.Queue diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index bb294563d..97b9b18d7 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -73,6 +73,8 @@ type FileDescription struct { // writable is analogous to Linux's FMODE_WRITE. writable bool + usedLockBSD uint32 + // impl is the FileDescriptionImpl associated with this Filesystem. impl is // immutable. This should be the last field in FileDescription. impl FileDescriptionImpl @@ -175,6 +177,12 @@ func (fd *FileDescription) DecRef() { } ep.interestMu.Unlock() } + + // If BSD locks were used, release any lock that it may have acquired. + if atomic.LoadUint32(&fd.usedLockBSD) != 0 { + fd.impl.UnlockBSD(context.Background(), fd) + } + // Release implementation resources. fd.impl.Release() if fd.writable { @@ -420,13 +428,9 @@ type FileDescriptionImpl interface { Removexattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): BSD-style file locking LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error - // LockBSD releases a BSD-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): BSD-style file locking + // UnlockBSD releases a BSD-style advisory file lock. UnlockBSD(ctx context.Context, uid lock.UniqueID) error // LockPOSIX tries to acquire a POSIX-style advisory file lock. @@ -736,3 +740,14 @@ func (fd *FileDescription) InodeID() uint64 { func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error { return fd.Sync(ctx) } + +// LockBSD tries to acquire a BSD-style advisory file lock. +func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error { + atomic.StoreUint32(&fd.usedLockBSD, 1) + return fd.impl.LockBSD(ctx, fd, lockType, blocker) +} + +// UnlockBSD releases a BSD-style advisory file lock. +func (fd *FileDescription) UnlockBSD(ctx context.Context) error { + return fd.impl.UnlockBSD(ctx, fd) +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index f4c111926..af7213dfd 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -21,8 +21,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs/lock" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -153,26 +154,6 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) return syserror.ENOTSUP } -// LockBSD implements FileDescriptionImpl.LockBSD. -func (FileDescriptionDefaultImpl) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error { - return syserror.EBADF -} - -// UnlockBSD implements FileDescriptionImpl.UnlockBSD. -func (FileDescriptionDefaultImpl) UnlockBSD(ctx context.Context, uid lock.UniqueID) error { - return syserror.EBADF -} - -// LockPOSIX implements FileDescriptionImpl.LockPOSIX. -func (FileDescriptionDefaultImpl) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error { - return syserror.EBADF -} - -// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. -func (FileDescriptionDefaultImpl) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error { - return syserror.EBADF -} - // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. @@ -384,3 +365,60 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M fd.IncRef() return nil } + +// LockFD may be used by most implementations of FileDescriptionImpl.Lock* +// functions. Caller must call Init(). +type LockFD struct { + locks *lock.FileLocks +} + +// Init initializes fd with FileLocks to use. +func (fd *LockFD) Init(locks *lock.FileLocks) { + fd.locks = locks +} + +// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { + return fd.locks.LockBSD(uid, t, block) +} + +// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + fd.locks.UnlockBSD(uid) + return nil +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { + return fd.locks.LockPOSIX(uid, t, rng, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { + fd.locks.UnlockPOSIX(uid, rng) + return nil +} + +// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface +// returning ENOLCK. +type NoLockFD struct{} + +// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { + return syserror.ENOLCK +} + +// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + return syserror.ENOLCK +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { + return syserror.ENOLCK +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { + return syserror.ENOLCK +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 3a75d4d62..5061f6ac9 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -33,6 +33,7 @@ import ( type fileDescription struct { vfsfd FileDescription FileDescriptionDefaultImpl + NoLockFD } // genCount contains the number of times its DynamicBytesSource.Generate() diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 05a3051a4..7fa7d2d0c 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -57,6 +57,7 @@ type Inotify struct { vfsfd FileDescription FileDescriptionDefaultImpl DentryMetadataFileDescriptionImpl + NoLockFD // Unique identifier for this inotify instance. We don't just reuse the // inotify fd because fds can be duped. These should not be exposed to the diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ae2aa44dc..4a1486e14 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -802,10 +802,13 @@ cc_binary( ], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:file_descriptor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:epoll_util", + "//test/util:eventfd_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index 3ecb8db8e..638a93979 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -21,6 +21,7 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/file_base.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -34,11 +35,6 @@ namespace { class FlockTest : public FileTest {}; -TEST_F(FlockTest, BadFD) { - // EBADF: fd is not an open file descriptor. - ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF)); -} - TEST_F(FlockTest, InvalidOpCombinations) { // The operation cannot be both exclusive and shared. EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB), @@ -57,15 +53,6 @@ TEST_F(FlockTest, NoOperationSpecified) { SyscallFailsWithErrno(EINVAL)); } -TEST(FlockTestNoFixture, FlockSupportsPipes) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds()); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); -} - TEST_F(FlockTest, TestSimpleExLock) { // Test that we can obtain an exclusive lock (no other holders) // and that we can unlock it. @@ -583,6 +570,66 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) { EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); } +TEST(FlockTestNoFixture, BadFD) { + // EBADF: fd is not an open file descriptor. + ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF)); +} + +TEST(FlockTestNoFixture, FlockDir) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockSymlink) { + // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH + // is supported. + SKIP_IF(IsRunningOnGvisor()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path())); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EBADF)); +} + +TEST(FlockTestNoFixture, FlockProc) { + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockPipe) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + + EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds()); + // Check that the pipe was locked above. + EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EAGAIN)); + + EXPECT_THAT(flock(fds[0], LOCK_UN), SyscallSucceeds()); + EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallSucceeds()); + + EXPECT_THAT(close(fds[0]), SyscallSucceeds()); + EXPECT_THAT(close(fds[1]), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockSocket) { + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX)); + ASSERT_THAT( + bind(sock, reinterpret_cast(&addr), sizeof(addr)), + SyscallSucceeds()); + + EXPECT_THAT(flock(sock, LOCK_EX | LOCK_NB), SyscallSucceeds()); + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + } // namespace } // namespace testing -- cgit v1.2.3 From 4b9652d63b319414e764696f1b77ee39cd36d96d Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Wed, 10 Jun 2020 13:36:02 -0700 Subject: {S,G}etsockopt for TCP_KEEPCNT option. TCP_KEEPCNT is used to set the maximum keepalive probes to be sent before dropping the connection. WANT_LGTM=jchacon PiperOrigin-RevId: 315758094 --- pkg/abi/linux/tcp.go | 1 + pkg/sentry/socket/netstack/netstack.go | 35 ++++++++----- test/syscalls/linux/socket_ip_tcp_generic.cc | 73 ++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 11 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go index 174d470e2..2a8d4708b 100644 --- a/pkg/abi/linux/tcp.go +++ b/pkg/abi/linux/tcp.go @@ -57,4 +57,5 @@ const ( const ( MAX_TCP_KEEPIDLE = 32767 MAX_TCP_KEEPINTVL = 32767 + MAX_TCP_KEEPCNT = 127 ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e1e0c5931..738277391 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1246,6 +1246,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_KEEPCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1786,6 +1798,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_KEEPCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPCNT { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v))) + case linux.TCP_USER_TIMEOUT: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2115,30 +2138,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { switch name { case linux.TCP_CONGESTION, linux.TCP_CORK, - linux.TCP_DEFER_ACCEPT, linux.TCP_FASTOPEN, linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_KEEPCNT, - linux.TCP_KEEPIDLE, - linux.TCP_KEEPINTVL, - linux.TCP_LINGER2, - linux.TCP_MAXSEG, linux.TCP_QUEUE_SEQ, - linux.TCP_QUICKACK, linux.TCP_REPAIR, linux.TCP_REPAIR_QUEUE, linux.TCP_REPAIR_WINDOW, linux.TCP_SAVED_SYN, linux.TCP_SAVE_SYN, - linux.TCP_SYNCNT, linux.TCP_THIN_DUPACK, linux.TCP_THIN_LINEAR_TIMEOUTS, linux.TCP_TIMESTAMP, - linux.TCP_ULP, - linux.TCP_USER_TIMEOUT, - linux.TCP_WINDOW_CLAMP: + linux.TCP_ULP: t.Kernel().EmitUnimplementedEvent(t) } diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index fa81845fd..15adc8d0e 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -524,6 +524,7 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) { // Copied from include/net/tcp.h. constexpr int MAX_TCP_KEEPIDLE = 32767; constexpr int MAX_TCP_KEEPINTVL = 32767; +constexpr int MAX_TCP_KEEPCNT = 127; TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -575,6 +576,78 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) { EXPECT_EQ(get, MAX_TCP_KEEPINTVL); } +TEST_P(TCPSocketPairTest, TCPKeepcountDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 9); // 9 keepalive probes. +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero, + sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &kAboveMax, sizeof(kAboveMax)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, MAX_TCP_KEEPCNT); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int keepaliveCount = 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &keepaliveCount, sizeof(keepaliveCount)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, keepaliveCount); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int keepaliveCount = -5; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &keepaliveCount, sizeof(keepaliveCount)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(TCPSocketPairTest, SetOOBInline) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); -- cgit v1.2.3 From 96519e2c9d3fa1f15537c4dfc081a19d8d1ce1a2 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Wed, 17 Jun 2020 10:02:41 -0700 Subject: Implement POSIX locks - Change FileDescriptionImpl Lock/UnlockPOSIX signature to take {start,length,whence}, so the correct offset can be calculated in the implementations. - Create PosixLocker interface to make it possible to share the same locking code from different implementations. Closes #1480 PiperOrigin-RevId: 316910286 --- pkg/sentry/fsimpl/devpts/BUILD | 2 +- pkg/sentry/fsimpl/devpts/devpts.go | 3 +- pkg/sentry/fsimpl/devpts/master.go | 14 +++- pkg/sentry/fsimpl/devpts/slave.go | 14 +++- pkg/sentry/fsimpl/ext/BUILD | 2 +- pkg/sentry/fsimpl/ext/directory.go | 11 +++ pkg/sentry/fsimpl/ext/inode.go | 3 +- pkg/sentry/fsimpl/ext/regular_file.go | 11 +++ pkg/sentry/fsimpl/ext/symlink.go | 1 + pkg/sentry/fsimpl/gofer/BUILD | 1 - pkg/sentry/fsimpl/gofer/gofer.go | 12 ++- pkg/sentry/fsimpl/gofer/special_file.go | 3 +- pkg/sentry/fsimpl/host/BUILD | 2 +- pkg/sentry/fsimpl/host/host.go | 14 +++- pkg/sentry/fsimpl/host/tty.go | 11 +++ pkg/sentry/fsimpl/kernfs/BUILD | 3 +- pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 16 +++- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 16 +++- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 3 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 5 +- pkg/sentry/fsimpl/overlay/BUILD | 2 +- pkg/sentry/fsimpl/overlay/overlay.go | 14 +++- pkg/sentry/fsimpl/pipefs/BUILD | 1 - pkg/sentry/fsimpl/pipefs/pipefs.go | 3 +- pkg/sentry/fsimpl/proc/BUILD | 2 +- pkg/sentry/fsimpl/proc/subtasks.go | 3 +- pkg/sentry/fsimpl/proc/task.go | 3 +- pkg/sentry/fsimpl/proc/task_fds.go | 3 +- pkg/sentry/fsimpl/proc/task_files.go | 14 +++- pkg/sentry/fsimpl/proc/tasks.go | 3 +- pkg/sentry/fsimpl/sys/BUILD | 1 - pkg/sentry/fsimpl/sys/sys.go | 3 +- pkg/sentry/fsimpl/tmpfs/BUILD | 1 - pkg/sentry/fsimpl/tmpfs/regular_file_test.go | 33 +++----- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 15 +++- pkg/sentry/kernel/fd_table.go | 10 ++- pkg/sentry/kernel/pipe/BUILD | 2 +- pkg/sentry/kernel/pipe/vfs.go | 18 +++- pkg/sentry/socket/hostinet/BUILD | 2 +- pkg/sentry/socket/hostinet/socket_vfs2.go | 14 +++- pkg/sentry/socket/netlink/BUILD | 2 +- pkg/sentry/socket/netlink/socket_vfs2.go | 14 +++- pkg/sentry/socket/netstack/BUILD | 2 +- pkg/sentry/socket/netstack/netstack_vfs2.go | 14 +++- pkg/sentry/socket/unix/BUILD | 2 +- pkg/sentry/socket/unix/unix_vfs2.go | 16 +++- pkg/sentry/syscalls/linux/vfs2/fd.go | 38 +++++++++ pkg/sentry/vfs/BUILD | 2 +- pkg/sentry/vfs/file_description.go | 18 ++-- pkg/sentry/vfs/file_description_impl_util.go | 25 ++---- pkg/sentry/vfs/lock.go | 109 +++++++++++++++++++++++++ pkg/sentry/vfs/lock/BUILD | 13 --- pkg/sentry/vfs/lock/lock.go | 72 ---------------- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/fcntl.cc | 96 +++++++++++++++------- 55 files changed, 484 insertions(+), 234 deletions(-) create mode 100644 pkg/sentry/vfs/lock.go delete mode 100644 pkg/sentry/vfs/lock/BUILD delete mode 100644 pkg/sentry/vfs/lock/lock.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index cf440dce8..93512c9b6 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -18,12 +18,12 @@ go_library( "//pkg/context", "//pkg/safemem", "//pkg/sentry/arch", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 9b0e0cca2..e6fda2b4f 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -117,7 +116,7 @@ type rootInode struct { kernfs.InodeNotSymlink kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. dentry kernfs.Dentry diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 1d22adbe3..69879498a 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -18,11 +18,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -35,7 +35,7 @@ type masterInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -189,6 +189,16 @@ func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions return mfd.inode.Stat(fs, opts) } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence) +} + // maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid. func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { switch cmd { diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index 7fe475080..cf1a0f0ac 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -18,10 +18,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -34,7 +34,7 @@ type slaveInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -185,3 +185,13 @@ func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() return sfd.inode.Stat(fs, opts) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 973fa0def..ef24f8159 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -54,13 +54,13 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 43be6928a..357512c7e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -305,3 +306,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in fd.off = offset return offset, nil } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 5caaf14ed..30636cf66 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -55,7 +54,7 @@ type inode struct { // diskInode gives us access to the inode struct on disk. Immutable. diskInode disklayout.Inode - locks lock.FileLocks + locks vfs.FileLocks // This is immutable. The first field of the implementations must have inode // as the first field to ensure temporality. diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 152036b2e..66d14bb95 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -149,3 +150,13 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index acb28d85b..62efd4095 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -66,6 +66,7 @@ func (in *inode) isSymlink() bool { // O_PATH. For this reason most of the functions return EBADF. type symlinkFD struct { fileDescription + vfs.NoLockFD } // Compiles only if symlinkFD implements vfs.FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 5cdeeaeb5..4a800dcf9 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -69,7 +69,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index ac051b3a7..d8ae475ed 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -53,7 +53,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -665,7 +664,7 @@ type dentry struct { // endpoint bound to this file. pipe *pipe.VFSPipe - locks lock.FileLocks + locks vfs.FileLocks } // dentryAttrMask returns a p9.AttrMask enabling all attributes used by the @@ -1439,9 +1438,14 @@ func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t f } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { fd.lockLogging.Do(func() { log.Infof("Range lock using gofer file handled internally.") }) - return fd.LockFD.LockPOSIX(ctx, uid, t, rng, block) + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) } diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 289efdd25..e6e29b329 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -52,7 +51,7 @@ type specialFileFD struct { off int64 } -func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks, flags uint32) (*specialFileFD, error) { +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 54f16ad63..44a09d87a 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", @@ -39,7 +40,6 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 5ec5100b8..7906242c9 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -28,13 +28,13 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -183,7 +183,7 @@ type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // When the reference count reaches zero, the host fd is closed. refs.AtomicRefCount @@ -718,3 +718,13 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) { func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 68af6e5af..0fbc543b1 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -377,3 +378,13 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) return kernel.ERESTARTSYS } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 0299dbde9..179df6c1e 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -45,11 +45,11 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", @@ -68,7 +68,6 @@ go_test( "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 6418de0a3..c1215b70a 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -19,9 +19,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -39,7 +39,7 @@ type DynamicBytesFile struct { InodeNotDirectory InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks data vfs.DynamicBytesSource } @@ -86,7 +86,7 @@ type DynamicBytesFD struct { } // Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *lock.FileLocks, flags uint32) error { +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { fd.LockFD.Init(locks) if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { return err @@ -135,3 +135,13 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { // DynamicBytesFiles are immutable. return syserror.EPERM } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 33a5968ca..5f7853a2a 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -57,7 +57,7 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} if err := fd.Init(children, locks, opts); err != nil { return nil, err @@ -71,7 +71,7 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre // Init initializes a GenericDirectoryFD. Use it when overriding // GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the // correct implementation. -func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) error { +func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. return syserror.EISDIR @@ -235,3 +235,13 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode return inode.SetStat(ctx, fd.filesystem(), creds, opts) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 0e4927215..650bd7b88 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -557,7 +556,7 @@ type StaticDirectory struct { InodeNoDynamicLookup OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks } var _ Inode = (*StaticDirectory)(nil) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 6749facf7..dc407eb1d 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -103,7 +102,7 @@ type readonlyDir struct { kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks dentry kernfs.Dentry } @@ -133,7 +132,7 @@ type dir struct { kernfs.InodeNoDynamicLookup kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem dentry kernfs.Dentry diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index f9413bbdd..8cf5b35d3 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -29,11 +29,11 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e660d0e2c..e11a3ff19 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -35,9 +35,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -415,7 +415,7 @@ type dentry struct { devMinor uint32 ino uint64 - locks lock.FileLocks + locks vfs.FileLocks } // newDentry creates a new dentry. The dentry initially has no references; it @@ -610,3 +610,13 @@ func (fd *fileDescription) filesystem() *filesystem { func (fd *fileDescription) dentry() *dentry { return fd.vfsfd.Dentry().Impl().(*dentry) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD index c618dbe6c..5950a2d59 100644 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -15,7 +15,6 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e4dabaa33..dd7eaf4a8 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -82,7 +81,7 @@ type inode struct { kernfs.InodeNotSymlink kernfs.InodeNoopRefCount - locks lock.FileLocks + locks vfs.FileLocks pipe *pipe.VFSPipe ino uint64 diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 351ba4ee9..6014138ff 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -22,6 +22,7 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/safemem", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index e2cdb7ee9..36a89540c 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -38,7 +37,7 @@ type subtasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem task *kernel.Task diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 44078a765..8bb2b0ce1 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -39,7 +38,7 @@ type taskInode struct { kernfs.InodeAttrs kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks task *kernel.Task } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index ef6c1d04f..7debdb07a 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -54,7 +53,7 @@ func taskFDExists(t *kernel.Task, fd int32) bool { } type fdDir struct { - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem task *kernel.Task diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index e5eaa91cd..ba4405026 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -30,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -777,7 +777,7 @@ type namespaceInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks } var _ kernfs.Inode = (*namespaceInode)(nil) @@ -830,3 +830,13 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release() { fd.inode.DecRef() } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 58c8b9d05..2f214d0c2 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -44,7 +43,7 @@ type tasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem pidns *kernel.PIDNamespace diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 237f17def..a741e2bb6 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -15,7 +15,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index b84463d3a..fe02f7ee9 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -101,7 +100,7 @@ type dir struct { kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks dentry kernfs.Dentry } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 062321cbc..e73732a6b 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -62,7 +62,6 @@ go_library( "//pkg/sentry/uniqueid", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sentry/vfs/memxattr", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 64e1c40ad..146c7fdfe 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -138,48 +138,37 @@ func TestLocks(t *testing.T) { } defer cleanup() - var ( - uid1 lock.UniqueID - uid2 lock.UniqueID - // Non-blocking. - block lock.Blocker - ) - - uid1 = 123 - uid2 = 456 - - if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, block); err != nil { + uid1 := 123 + uid2 := 456 + if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, block); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - rng1 := lock.LockRange{0, 1} - rng2 := lock.LockRange{1, 2} - - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, rng1, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng2, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, rng1, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng1, block), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, rng1); err != nil { + if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil { t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 71a7522af..d0a3e1a5c 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,11 +36,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -310,7 +310,7 @@ type inode struct { ctime int64 // nanoseconds mtime int64 // nanoseconds - locks lock.FileLocks + locks vfs.FileLocks // Inotify watches for this inode. watches vfs.Watches @@ -761,9 +761,20 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with // FMODE_READ | FMODE_WRITE. var fd regularFileFD + fd.Init(&inode.locks) flags := uint32(linux.O_RDWR) if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 48911240f..4b7d234a4 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" ) // FDFlags define flags for an individual descriptor. @@ -148,7 +149,12 @@ func (f *FDTable) drop(file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(file *vfs.FileDescription) { - // TODO(gvisor.dev/issue/1480): Release locks. + // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the + // entire file. + err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET) + if err != nil && err != syserror.ENOLCK { + panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + } // Generate inotify events. ev := uint32(linux.IN_CLOSE_NOWRITE) @@ -157,7 +163,7 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) { } file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) - // Drop the table reference. + // Drop the table's reference. file.DecRef() } diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 0db546b98..449643118 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -26,8 +26,8 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index c0e9ee1f4..a4519363f 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -63,12 +63,12 @@ func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { // Preconditions: statusFlags should not contain an open access mode. func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { // Connected pipes share the same locks. - locks := &lock.FileLocks{} + locks := &vfs.FileLocks{} return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } // Open opens the pipe represented by vp. -func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) (*vfs.FileDescription, error) { +func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() defer vp.mu.Unlock() @@ -130,7 +130,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription { fd := &VFSPipeFD{ pipe: &vp.pipe, } @@ -451,3 +451,13 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } return n, err } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 60c9896fc..ff81ea6e6 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", @@ -34,7 +35,6 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 027add1fd..ad5f64799 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,12 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -61,7 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in fd: fd, }, } - s.LockFD.Init(&lock.FileLocks{}) + s.LockFD.Init(&vfs.FileLocks{}) if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -134,6 +134,16 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return int64(n), err } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + type socketProviderVFS2 struct { family int } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 420e573c9..d5ca3ac56 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -29,7 +30,6 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index 8bfee5193..dbcd8b49a 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -18,12 +18,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -78,7 +78,7 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV sendBufferSize: defaultSendBufferSize, }, } - fd.LockFD.Init(&lock.FileLocks{}) + fd.LockFD.Init(&vfs.FileLocks{}) return fd, nil } @@ -140,3 +140,13 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0f592ecc3..ea6ebd0e2 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -37,7 +38,6 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 1412a4810..d65a89316 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -19,13 +19,13 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -66,7 +66,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu protocol: protocol, }, } - s.LockFD.Init(&lock.FileLocks{}) + s.LockFD.Init(&vfs.FileLocks{}) vfsfd := &s.vfsfd if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, @@ -318,3 +318,13 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 7d4cc80fe..cca5e70f1 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", @@ -29,7 +30,6 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 8c32371a2..ff2149250 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -53,7 +53,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) - fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &lock.FileLocks{}) + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) if err != nil { return nil, syserr.FromError(err) } @@ -62,7 +62,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) // NewFileDescription creates and returns a socket file description // corresponding to the given mount and dentry. -func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *lock.FileLocks) (*vfs.FileDescription, error) { +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) { // You can create AF_UNIX, SOCK_RAW sockets. They're the same as // SOCK_DGRAM and don't require CAP_NET_RAW. if stype == linux.SOCK_RAW { @@ -300,6 +300,16 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + // providerVFS2 is a unix domain socket provider for VFS2. type providerVFS2 struct{} diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index f9ccb303c..f5eaa076b 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -17,10 +17,12 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -167,8 +169,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } err := tmpfs.AddSeals(file, args[2].Uint()) return 0, nil, err + case linux.F_SETLK, linux.F_SETLKW: + return 0, nil, posixLock(t, args, file, cmd) default: // TODO(gvisor.dev/issue/2920): Everything else is not yet supported. return 0, nil, syserror.EINVAL } } + +func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error { + // Copy in the lock request. + flockAddr := args[2].Pointer() + var flock linux.Flock + if _, err := t.CopyIn(flockAddr, &flock); err != nil { + return err + } + + var blocker lock.Blocker + if cmd == linux.F_SETLKW { + blocker = t + } + + switch flock.Type { + case linux.F_RDLCK: + if !file.IsReadable() { + return syserror.EBADF + } + return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + + case linux.F_WRLCK: + if !file.IsWritable() { + return syserror.EBADF + } + return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + + case linux.F_UNLCK: + return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence) + + default: + return syserror.EINVAL + } +} diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 16d9f3a28..642769e7c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -44,6 +44,7 @@ go_library( "filesystem_impl_util.go", "filesystem_type.go", "inotify.go", + "lock.go", "mount.go", "mount_unsafe.go", "options.go", @@ -72,7 +73,6 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 13c48824e..e0538ea53 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -438,14 +438,10 @@ type FileDescriptionImpl interface { UnlockBSD(ctx context.Context, uid lock.UniqueID) error // LockPOSIX tries to acquire a POSIX-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): POSIX-style file locking - LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error + LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error // UnlockPOSIX releases a POSIX-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): POSIX-style file locking - UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error + UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error } // Dirent holds the information contained in struct linux_dirent64. @@ -764,3 +760,13 @@ func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, func (fd *FileDescription) UnlockBSD(ctx context.Context) error { return fd.impl.UnlockBSD(ctx, fd) } + +// LockPOSIX locks a POSIX-style file range lock. +func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error { + return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block) +} + +// UnlockPOSIX unlocks a POSIX-style file range lock. +func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error { + return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence) +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index af7213dfd..1e66997ce 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -369,14 +368,19 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M // LockFD may be used by most implementations of FileDescriptionImpl.Lock* // functions. Caller must call Init(). type LockFD struct { - locks *lock.FileLocks + locks *FileLocks } // Init initializes fd with FileLocks to use. -func (fd *LockFD) Init(locks *lock.FileLocks) { +func (fd *LockFD) Init(locks *FileLocks) { fd.locks = locks } +// Locks returns the locks associated with this file. +func (fd *LockFD) Locks() *FileLocks { + return fd.locks +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { return fd.locks.LockBSD(uid, t, block) @@ -388,17 +392,6 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - return fd.locks.LockPOSIX(uid, t, rng, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { - fd.locks.UnlockPOSIX(uid, rng) - return nil -} - // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. type NoLockFD struct{} @@ -414,11 +407,11 @@ func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { +func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return syserror.ENOLCK } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { +func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return syserror.ENOLCK } diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go new file mode 100644 index 000000000..6c7583a81 --- /dev/null +++ b/pkg/sentry/vfs/lock.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lock provides POSIX and BSD style file locking for VFS2 file +// implementations. +// +// The actual implementations can be found in the lock package under +// sentry/fs/lock. +package vfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/syserror" +) + +// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2) +// and flock(2) respectively in Linux. It can be embedded into various file +// implementations for VFS2 that support locking. +// +// Note that in Linux these two types of locks are _not_ cooperative, because +// race and deadlock conditions make merging them prohibitive. We do the same +// and keep them oblivious to each other. +type FileLocks struct { + // bsd is a set of BSD-style advisory file wide locks, see flock(2). + bsd fslock.Locks + + // posix is a set of POSIX-style regional advisory locks, see fcntl(2). + posix fslock.Locks +} + +// LockBSD tries to acquire a BSD-style lock on the entire file. +func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { + if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { + return nil + } + return syserror.ErrWouldBlock +} + +// UnlockBSD releases a BSD-style lock on the entire file. +// +// This operation is always successful, even if there did not exist a lock on +// the requested region held by uid in the first place. +func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { + fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF}) +} + +// LockPOSIX tries to acquire a POSIX-style lock on a file region. +func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + rng, err := computeRange(ctx, fd, start, length, whence) + if err != nil { + return err + } + if fl.posix.LockRegion(uid, t, rng, block) { + return nil + } + return syserror.ErrWouldBlock +} + +// UnlockPOSIX releases a POSIX-style lock on a file region. +// +// This operation is always successful, even if there did not exist a lock on +// the requested region held by uid in the first place. +func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error { + rng, err := computeRange(ctx, fd, start, length, whence) + if err != nil { + return err + } + fl.posix.UnlockRegion(uid, rng) + return nil +} + +func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) { + var off int64 + switch whence { + case linux.SEEK_SET: + off = 0 + case linux.SEEK_CUR: + // Note that Linux does not hold any mutexes while retrieving the file + // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk. + curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) + if err != nil { + return fslock.LockRange{}, err + } + off = curOff + case linux.SEEK_END: + stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) + if err != nil { + return fslock.LockRange{}, err + } + off = int64(stat.Size) + default: + return fslock.LockRange{}, syserror.EINVAL + } + + return fslock.ComputeRange(int64(start), int64(length), off) +} diff --git a/pkg/sentry/vfs/lock/BUILD b/pkg/sentry/vfs/lock/BUILD deleted file mode 100644 index d9ab063b7..000000000 --- a/pkg/sentry/vfs/lock/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "lock", - srcs = ["lock.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/fs/lock", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/vfs/lock/lock.go b/pkg/sentry/vfs/lock/lock.go deleted file mode 100644 index 724dfe743..000000000 --- a/pkg/sentry/vfs/lock/lock.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. -package lock - -import ( - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/syserror" -) - -// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2) -// and flock(2) respectively in Linux. It can be embedded into various file -// implementations for VFS2 that support locking. -// -// Note that in Linux these two types of locks are _not_ cooperative, because -// race and deadlock conditions make merging them prohibitive. We do the same -// and keep them oblivious to each other. -type FileLocks struct { - // bsd is a set of BSD-style advisory file wide locks, see flock(2). - bsd fslock.Locks - - // posix is a set of POSIX-style regional advisory locks, see fcntl(2). - posix fslock.Locks -} - -// LockBSD tries to acquire a BSD-style lock on the entire file. -func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { - return nil - } - return syserror.ErrWouldBlock -} - -// UnlockBSD releases a BSD-style lock on the entire file. -// -// This operation is always successful, even if there did not exist a lock on -// the requested region held by uid in the first place. -func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { - fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF}) -} - -// LockPOSIX tries to acquire a POSIX-style lock on a file region. -func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - if fl.posix.LockRegion(uid, t, rng, block) { - return nil - } - return syserror.ErrWouldBlock -} - -// UnlockPOSIX releases a POSIX-style lock on a file region. -// -// This operation is always successful, even if there did not exist a lock on -// the requested region held by uid in the first place. -func (fl *FileLocks) UnlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) { - fl.posix.UnlockRegion(uid, rng) -} diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 96044928e..078e4a284 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -791,6 +791,7 @@ cc_binary( deps = [ ":socket_test_util", "//test/util:cleanup", + "//test/util:epoll_util", "//test/util:eventfd_util", "//test/util:fs_util", "@com_google_absl//absl/base:core_headers", diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 35e8a4ff3..25bef2522 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -191,45 +191,85 @@ TEST(FcntlTest, SetFlags) { EXPECT_EQ(rflags, expected); } -TEST_F(FcntlLockTest, SetLockBadFd) { +void TestLock(int fd, short lock_type = F_RDLCK) { // NOLINT, type in flock struct flock fl; - fl.l_type = F_WRLCK; + fl.l_type = lock_type; fl.l_whence = SEEK_SET; fl.l_start = 0; - // len 0 has a special meaning: lock all bytes despite how - // large the file grows. + // len 0 locks all bytes despite how large the file grows. fl.l_len = 0; - EXPECT_THAT(fcntl(-1, F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallSucceeds()); } -TEST_F(FcntlLockTest, SetLockPipe) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - +void TestLockBadFD(int fd, + short lock_type = F_RDLCK) { // NOLINT, type in flock struct flock fl; - fl.l_type = F_WRLCK; + fl.l_type = lock_type; fl.l_whence = SEEK_SET; fl.l_start = 0; - // Same as SetLockBadFd, but doesn't matter, we expect this to fail. + // len 0 locks all bytes despite how large the file grows. fl.l_len = 0; - EXPECT_THAT(fcntl(fds[0], F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); + EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); } +TEST_F(FcntlLockTest, SetLockBadFd) { TestLockBadFD(-1); } + TEST_F(FcntlLockTest, SetLockDir) { auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0666)); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000)); + TestLock(fd.get()); +} - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; +TEST_F(FcntlLockTest, SetLockSymlink) { + // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH + // is supported. + SKIP_IF(IsRunningOnGvisor()); - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path())); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000)); + TestLockBadFD(fd.get()); +} + +TEST_F(FcntlLockTest, SetLockProc) { + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000)); + TestLock(fd.get()); +} + +TEST_F(FcntlLockTest, SetLockPipe) { + SKIP_IF(IsRunningWithVFS1()); + + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + + TestLock(fds[0]); + TestLockBadFD(fds[0], F_WRLCK); + + TestLock(fds[1], F_WRLCK); + TestLockBadFD(fds[1]); + + EXPECT_THAT(close(fds[0]), SyscallSucceeds()); + EXPECT_THAT(close(fds[1]), SyscallSucceeds()); +} + +TEST_F(FcntlLockTest, SetLockSocket) { + SKIP_IF(IsRunningWithVFS1()); + + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX)); + ASSERT_THAT( + bind(sock, reinterpret_cast(&addr), sizeof(addr)), + SyscallSucceeds()); + + TestLock(sock); + EXPECT_THAT(close(sock), SyscallSucceeds()); } TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) { @@ -241,8 +281,7 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) { fl0.l_type = F_WRLCK; fl0.l_whence = SEEK_SET; fl0.l_start = 0; - // Same as SetLockBadFd. - fl0.l_len = 0; + fl0.l_len = 0; // Lock all file // Expect that setting a write lock using a read only file descriptor // won't work. @@ -704,7 +743,7 @@ TEST_F(FcntlLockTest, SetWriteLockThenBlockingWriteLock) { << "Exited with code: " << status; } -// This test will veirfy that blocking works as expected when another process +// This test will verify that blocking works as expected when another process // holds a read lock when obtaining a write lock. This test will hold the lock // for some amount of time and then wait for the second process to send over the // socket_fd the amount of time it was blocked for before the lock succeeded. @@ -1109,8 +1148,7 @@ int main(int argc, char** argv) { fl.l_start = absl::GetFlag(FLAGS_child_setlock_start); fl.l_len = absl::GetFlag(FLAGS_child_setlock_len); - // Test the fcntl, no need to log, the error is unambiguously - // from fcntl at this point. + // Test the fcntl. int err = 0; int ret = 0; @@ -1123,6 +1161,8 @@ int main(int argc, char** argv) { if (ret == -1 && errno != 0) { err = errno; + std::cerr << "CHILD lock " << setlock_on << " failed " << err + << std::endl; } // If there is a socket fd let's send back the time in microseconds it took -- cgit v1.2.3 From 07ff909e76d8233827e705476ec116fc2cecec2f Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 18 Jun 2020 06:05:47 -0700 Subject: Support setsockopt SO_SNDBUF/SO_RCVBUF for raw/udp sockets. Updates #173,#6 Fixes #2888 PiperOrigin-RevId: 317087652 --- benchmarks/tcp/tcp_proxy.go | 8 +- pkg/sentry/socket/netstack/stack.go | 12 +- pkg/tcpip/tcpip.go | 26 ++ pkg/tcpip/transport/raw/endpoint.go | 72 +++- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 +- pkg/tcpip/transport/tcp/endpoint_state.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 59 ++-- pkg/tcpip/transport/tcp/tcp_sack_test.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 26 +- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +- pkg/tcpip/transport/udp/endpoint.go | 57 +++- pkg/tcpip/transport/udp/protocol.go | 70 +++- runsc/boot/loader.go | 6 +- test/syscalls/linux/raw_socket_ipv4.cc | 379 +++++++++++++++++++++ test/syscalls/linux/socket_ipv4_udp_unbound.cc | 216 ++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 115 +++++++ 17 files changed, 984 insertions(+), 96 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index f5aa0b515..b3a4dbea3 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -228,19 +228,19 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(*sack)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) } // Enable Receive Buffer Auto-Tuning. if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) + return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) } // Set Congestion Control to cubic if requested. if *cubic { if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %v", err) + return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err) } } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 9b44c2b89..f97f9b6f3 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -144,7 +144,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcp.ReceiveBufferSizeOption + var rs tcpip.StackReceiveBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -155,7 +155,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcp.ReceiveBufferSizeOption{ + rs := tcpip.StackReceiveBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -165,7 +165,7 @@ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcp.SendBufferSizeOption + var ss tcpip.StackSendBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -176,7 +176,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcp.SendBufferSizeOption{ + ss := tcpip.StackSendBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -186,14 +186,14 @@ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcp.SACKEnabled + var sack tcpip.StackSACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enabled))).ToError() } // Statistics implements inet.Stack.Statistics. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b7b227328..3ad130b23 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -813,6 +813,32 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 +// StackSACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. +type StackSACKEnabled bool + +// StackDelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type StackDelayEnabled bool + +// StackSendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max send buffer sizes. +type StackSendBufferSizeOption struct { + Min int + Default int + Max int +} + +// StackReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// receive buffer sizes. +type StackReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// // IPPacketInfo is the message struture for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index a406d815e..6a7977259 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,6 +26,8 @@ package raw import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -66,16 +68,17 @@ type endpoint struct { // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` rcvList rawPacketList - rcvBufSizeMax int `state:".(int)"` rcvBufSize int + rcvBufSizeMax int `state:".(int)"` rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` @@ -103,10 +106,21 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, + sndBufSizeMax: 32 * 1024, associated: associated, } + // Override with stack defaults. + var ss tcpip.StackSendBufferSizeOption + if err := s.TransportProtocolOption(transProto, &ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs tcpip.StackReceiveBufferSizeOption + if err := s.TransportProtocolOption(transProto, &rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. @@ -523,7 +537,46 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss tcpip.StackSendBufferSizeOption + if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs tcpip.StackReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + e.rcvMu.Lock() + e.rcvBufSizeMax = v + e.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -563,7 +616,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.SendBufferSizeOption: e.mu.Lock() - v := e.sndBufSize + v := e.sndBufSizeMax e.mu.Unlock() return v, nil @@ -636,7 +689,6 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 9d4dce826..377643b82 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -521,7 +521,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.StackSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f225b00e7..10df2bcd5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -851,12 +851,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.StackSendBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.StackReceiveBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -871,7 +871,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.StackDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1588,7 +1588,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption + var rs tcpip.StackReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1638,7 +1638,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption + var ss tcpip.StackSendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1678,7 +1678,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.StackReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -2609,7 +2609,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.StackReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2690,7 +2690,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.StackSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index cbb779666..0bebec2d1 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -186,7 +186,7 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.StackSendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 73b8a6782..3cff55afa 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -71,29 +71,6 @@ const ( DefaultSynRetries = 6 ) -// SACKEnabled option can be used to enable SACK support in the TCP -// protocol. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. -type DelayEnabled bool - -// SendBufferSizeOption allows the default, min and max send buffer sizes for -// TCP endpoints to be queried or configured. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption allows the default, min and max receive buffer size -// for TCP endpoints to be queried or configured. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - const ( ccReno = "reno" ccCubic = "cubic" @@ -160,8 +137,8 @@ type protocol struct { mu sync.RWMutex sackEnabled bool delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.StackSendBufferSizeOption + recvBufferSize tcpip.StackReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -272,19 +249,19 @@ func replyWithReset(s *segment, tos, ttl uint8) { // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case tcpip.StackSACKEnabled: p.mu.Lock() p.sackEnabled = bool(v) p.mu.Unlock() return nil - case DelayEnabled: + case tcpip.StackDelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case tcpip.StackSendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -293,7 +270,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case tcpip.StackReceiveBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -386,25 +363,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.StackSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.StackSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.StackDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.StackDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.StackSendBufferSizeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.StackReceiveBufferSizeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -514,8 +491,16 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + sendBufferSize: tcpip.StackSendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultSendBufferSize, + Max: MaxBufferSize, + }, + recvBufferSize: tcpip.StackReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultReceiveBufferSize, + Max: MaxBufferSize, + }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index fcc165f17..812e503bc 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0668cedc9..aca6a7951 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3987,7 +3987,10 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -4001,8 +4004,11 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) } ep.Close() @@ -4029,11 +4035,11 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5672,7 +5678,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5793,7 +5799,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5935,7 +5941,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.StackDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5950,10 +5956,10 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.StackDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9721f6caf..9e262c272 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -144,12 +144,12 @@ func New(t *testing.T, mtu uint32) *Context { }) // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Increase minimum RTO in tests to avoid test flakes due to early @@ -1091,7 +1091,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.StackSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index df5efbf6a..f51988047 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,6 +15,8 @@ package udp import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -94,6 +96,7 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int + sndBufSizeMax int state EndpointState route stack.Route `state:"manual"` dstPort uint16 @@ -159,7 +162,7 @@ type multicastMembership struct { } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ + e := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -181,10 +184,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue multicastTTL: 1, multicastLoop: true, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, + sndBufSizeMax: 32 * 1024, state: StateInitial, uniqueID: s.UniqueID(), } + + // Override with stack defaults. + var ss tcpip.StackSendBufferSizeOption + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs tcpip.StackReceiveBufferSizeOption + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + + return e } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -611,8 +627,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.mu.Unlock() case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs tcpip.StackReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err)) + } + + if v < rs.Min { + v = rs.Min + } + if v > rs.Max { + v = rs.Max + } + + e.mu.Lock() + e.rcvBufSizeMax = v + e.mu.Unlock() + return nil case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss tcpip.StackSendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err)) + } + + if v < ss.Min { + v = ss.Min + } + if v > ss.Max { + v = ss.Max + } + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil } return nil @@ -861,7 +912,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.SendBufferSizeOption: e.mu.Lock() - v := e.sndBufSize + v := e.sndBufSizeMax e.mu.Unlock() return v, nil diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 4218e7d03..fc93f93c0 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -21,6 +21,7 @@ package udp import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -32,9 +33,27 @@ import ( const ( // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber + + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4KiB bytes. + + // DefaultSendBufferSize is the default size of the send buffer for + // an endpoint. + DefaultSendBufferSize = 32 << 10 // 32KiB + + // DefaultReceiveBufferSize is the default size of the receive buffer + // for an endpoint. + DefaultReceiveBufferSize = 32 << 10 // 32KiB + + // MaxBufferSize is the largest size a receive/send buffer can grow to. + MaxBufferSize = 4 << 20 // 4MiB ) -type protocol struct{} +type protocol struct { + mu sync.RWMutex + sendBufferSize tcpip.StackSendBufferSizeOption + recvBufferSize tcpip.StackReceiveBufferSizeOption +} // Number returns the udp protocol number. func (*protocol) Number() tcpip.TransportProtocolNumber { @@ -183,13 +202,49 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case tcpip.StackSendBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.sendBufferSize = v + p.mu.Unlock() + return nil + + case tcpip.StackReceiveBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.recvBufferSize = v + p.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *tcpip.StackSendBufferSizeOption: + p.mu.RLock() + *v = p.sendBufferSize + p.mu.RUnlock() + return nil + + case *tcpip.StackReceiveBufferSizeOption: + p.mu.RLock() + *v = p.recvBufferSize + p.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // Close implements stack.TransportProtocol.Close. @@ -212,5 +267,8 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{} + return &protocol{ + sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize}, + recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize}, + } } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index b05a8bd45..c6efcdc83 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1058,8 +1058,8 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { - return nil, fmt.Errorf("failed to enable SACK: %v", err) + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(true)); err != nil { + return nil, fmt.Errorf("failed to enable SACK: %s", err) } // Set default TTLs as required by socket/netstack. @@ -1068,7 +1068,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in // Enable Receive Buffer Auto-Tuning. if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) + return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) } s.FillIPTablesMetadata() diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc index cde2f07c9..0116c3e94 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket_ipv4.cc @@ -357,10 +357,389 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) { EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); } +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketRecvBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum receive buf size by trying to set it to zero. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketRecvBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover max buf size by trying to set the largest possible buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored. +TEST_P(RawSocketTest, SetSocketRecvBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover max buf size by trying to set a really large buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set a zero size receive buffer + // size. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketSendBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(RawSocketTest, SetSocketSendBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(RawSocketTest, SetSocketSendBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack + // matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} + void RawSocketTest::SendBuf(const char* buf, int buf_len) { ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len)); } +// Test that receive buffer limits are not enforced when the recv buffer is +// empty. +TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast(&addr_), sizeof(addr_)), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + { + // Send data of size min and verify that it's received. + std::vector buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + + // Receive the packet and make sure it's identical. + std::vector recv_buf(buf.size() + sizeof(struct iphdr)); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ( + memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + 0); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + // Receive the packet and make sure it's identical. + std::vector recv_buf(buf.size() + sizeof(struct iphdr)); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ( + memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + 0); + } +} + +TEST_P(RawSocketTest, RecvBufLimits) { + // TCP stack generates RSTs for unknown endpoints and it complicates the test + // as we have to deal with the RST packets as well. For testing the raw socket + // endpoints buffer limit enforcement we can just test for UDP. + // + // We don't use SKIP_IF here because root_test_runner explicitly fails if a + // test is skipped. + if (Protocol() == IPPROTO_TCP) { + return; + } + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast(&addr_), sizeof(addr_)), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast(&addr_), sizeof(addr_)), + SyscallSucceeds()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 2. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor()) { + // Linux doubles the value specified so just set to min. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT( + getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len), + SyscallSucceeds()); + } + + // Set a receive timeout so that we don't block forever on reads if the test + // fails. + struct timeval tv { + .tv_sec = 1, .tv_usec = 0, + }; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + { + std::vector buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + int sent = 4; + if (IsRunningOnGvisor()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + sent++; + } + + // Verify that the expected number of packets are available to be read. + for (int i = 0; i < sent - 1; i++) { + // Receive the packet and make sure it's identical. + std::vector recv_buf(buf.size() + sizeof(struct iphdr)); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), + buf.size()), + 0); + } + + // Assert that the last packet is dropped because the receive buffer should + // be full after the first four packets. + std::vector recv_buf(buf.size() + sizeof(struct iphdr)); + struct iovec iov = {}; + iov.iov_base = static_cast(const_cast(recv_buf.data())); + iov.iov_len = buf.size(); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf, int buf_len) { // It's safe to use const_cast here because sendmsg won't modify the iovec or diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 4db9553e1..de0f5f01b 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -2236,5 +2237,220 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); } + +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &below_min, + sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufAboveMax) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &above_max, + sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= rcvBufSz <= max is honored. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &quarter_sz, + sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// send buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufBelowMin) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover minimum buffer size by setting it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &below_min, + sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufAboveMax) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &above_max, + sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &quarter_sz, + sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 42521efef..cc1db3de8 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1543,5 +1543,120 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); EXPECT_EQ(received_tos, sent_tos); } + +TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Bind s_ to loopback. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + { + // Send data of size min and verify that it's received. + std::vector buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + std::vector received(buf.size()); + EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + + std::vector received(buf.size()); + EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } +} + +// Test that receive buffer limits are enforced. +TEST_P(UdpSocketTest, RecvBufLimits) { + // Bind s_ to loopback. + ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 4. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { + // Linux doubles the value specified so just set to min * 2. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT( + getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len), + SyscallSucceeds()); + } + + { + std::vector buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + int sent = 4; + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), + SyscallSucceedsWithValue(buf.size())); + sent++; + } + + for (int i = 0; i < sent - 1; i++) { + // Receive the data. + std::vector received(buf.size()); + EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); + } + + // The last receive should fail with EAGAIN as the last packet should have + // been dropped due to lack of space in the receive buffer. + std::vector received(buf.size()); + EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 28b8a5cc3ac538333756084da28d7f13f13b5c87 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 18 Jun 2020 17:00:47 -0700 Subject: iptables: remove metadata struct Metadata was useful for debugging and safety, but enough tests exist that we should see failures when (de)serialization is broken. It made stack initialization more cumbersome and it's also getting in the way of ip6tables. PiperOrigin-RevId: 317210653 --- pkg/sentry/socket/netfilter/netfilter.go | 103 ++++++------------------------- pkg/sentry/socket/netstack/stack.go | 6 -- pkg/tcpip/stack/iptables.go | 8 --- pkg/tcpip/stack/iptables_types.go | 16 +---- runsc/boot/loader.go | 2 - 5 files changed, 21 insertions(+), 114 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 66015e2bc..f7abe77d3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -41,19 +41,6 @@ const errorTargetName = "ERROR" // change the destination port/destination IP for packets. const redirectTargetName = "REDIRECT" -// Metadata is used to verify that we are correctly serializing and -// deserializing iptables into structs consumable by the iptables tool. We save -// a metadata struct when the tables are written, and when they are read out we -// verify that certain fields are the same. -// -// metadata is used by this serialization/deserializing code, not netstack. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 -} - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -83,29 +70,13 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, info.Name) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - nflog("%v", err) + nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - nflog("returning info: %+v", info) - return info, nil } @@ -118,23 +89,13 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, userEntries.Name) - if err != nil { - nflog("%v", err) - return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if meta != table.Metadata().(metadata) { - panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) - } if binary.Size(entries) > uintptr(outLen) { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -143,44 +104,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - table, ok := stk.IPTables().GetTable(tablename.String()) - if !ok { - return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) - } - return table, nil -} - -// FillIPTablesMetadata populates stack's IPTables with metadata. -func FillIPTablesMetadata(stk *stack.Stack) { - stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) - } - table.SetMetadata(metadata) - tables[name] = table - } - }) -} - // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { - // Return values. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + var entries linux.KernelIPTGetEntries - var meta metadata + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename) + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - copy(entries.Name[:], tablename) + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) @@ -189,14 +132,14 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - meta.HookEntry[hook] = entries.Size + info.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - meta.Underflow[underflow] = entries.Size + info.Underflow[underflow] = entries.Size } } @@ -251,12 +194,12 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI entries.Size += uint32(entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + info.NumEntries++ } - nflog("convert to binary: finished with an marshalled size of %d", meta.Size) - meta.Size = entries.Size - return entries, meta, nil + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } func marshalTarget(target stack.Target) []byte { @@ -569,12 +512,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - table.SetMetadata(metadata{ - HookEntry: replace.HookEntry, - Underflow: replace.Underflow, - NumEntries: replace.NumEntries, - Size: replace.Size, - }) stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f97f9b6f3..ee11742a6 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -366,11 +365,6 @@ func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillIPTablesMetadata populates stack's IPTables with metadata. -func (s *Stack) FillIPTablesMetadata() { - netfilter.FillIPTablesMetadata(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 4e9b404c8..dc2b77c9d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -173,14 +173,6 @@ func (it *IPTables) ReplaceTable(name string, table Table) { it.tables[name] = table } -// ModifyTables acquires write-lock and calls fn with internal name-to-table -// map. This function can be used to update multiple tables atomically. -func (it *IPTables) ModifyTables(fn func(map[string]Table)) { - it.mu.Lock() - defer it.mu.Unlock() - fn(it.tables) -} - // GetPriorities returns slice of priorities associated with hook. func (it *IPTables) GetPriorities(hook Hook) []string { it.mu.RLock() diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4a6a5c6f1..72f1dd329 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -95,7 +95,7 @@ type IPTables struct { } // A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. +// really just a list of rules. type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -110,10 +110,6 @@ type Table struct { // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} } // ValidHooks returns a bitmap of the builtin hooks for the given table. @@ -125,16 +121,6 @@ func (table *Table) ValidHooks() uint32 { return hooks } -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index c6efcdc83..081db39c1 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1071,8 +1071,6 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) } - s.FillIPTablesMetadata() - return &s, nil } -- cgit v1.2.3 From d962f9f3842c5c352bc61411cf27e38ba2219317 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 19 Jun 2020 11:41:37 -0700 Subject: Implement UDP cheksum verification. Test: - TestIncrementChecksumErrors Fixes #2943 PiperOrigin-RevId: 317348158 --- pkg/sentry/socket/netstack/netstack.go | 1 + pkg/sentry/socket/netstack/stack.go | 2 +- pkg/tcpip/tcpip.go | 6 + pkg/tcpip/transport/udp/endpoint.go | 18 ++ pkg/tcpip/transport/udp/udp_test.go | 320 +++++++++++++++++++++++++-------- 5 files changed, 276 insertions(+), 71 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 738277391..c0b63a803 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -191,6 +191,7 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), }, } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index ee11742a6..d2fb655ea 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -313,7 +313,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. 0, // Udp/SndbufErrors. - 0, // Udp/InCsumErrors. + udp.ChecksumErrors.Value(), // Udp/InCsumErrors. 0, // Udp/IgnoredMulti. } default: diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3ad130b23..956232a44 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1224,6 +1224,9 @@ type UDPStats struct { // PacketSendErrors is the number of datagrams failed to be sent. PacketSendErrors *StatCounter + + // ChecksumErrors is the number of datagrams dropped due to bad checksums. + ChecksumErrors *StatCounter } // Stats holds statistics about the networking stack. @@ -1267,6 +1270,9 @@ type ReceiveErrors struct { // ClosedReceiver is the number of received packets dropped because // of receiving endpoint state being closed. ClosedReceiver StatCounter + + // ChecksumErrors is the number of packets dropped due to bad checksums. + ChecksumErrors StatCounter } // SendErrors collects packet send errors within the transport layer for diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f51988047..40d66ef09 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1350,6 +1350,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Verify checksum unless RX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value means + // the transmitter omitted the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + if hdr.CalculateChecksum(xsum) != 0xffff { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return + } + } + e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 313a3f117..ff9f60cf9 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -292,15 +292,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio wep = sniffer.New(ep) } if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } s.SetRouteTable([]tcpip.Route{ @@ -391,17 +391,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } else { - c.injectV6Packet(payload, &h, true /* valid */) + buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } } -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -420,16 +424,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: l, + Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. @@ -439,17 +437,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -483,11 +476,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } func newPayload() []byte { @@ -509,7 +498,7 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() @@ -643,7 +632,7 @@ func TestBindEphemeralPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } } @@ -654,19 +643,19 @@ func TestBindReservedPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } addr, err := c.ep.GetLocalAddress() if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) + t.Fatalf("GetLocalAddress failed: %s", err) } // We can't bind the address reserved by the connected endpoint above. { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { @@ -677,7 +666,7 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint @@ -687,7 +676,7 @@ func TestBindReservedPort(t *testing.T) { } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() @@ -697,11 +686,11 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() } @@ -714,7 +703,7 @@ func TestV4ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -729,7 +718,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { // Bind to v4 mapped wildcard. if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -744,7 +733,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -759,7 +748,7 @@ func TestV6ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -796,7 +785,10 @@ func TestV4ReadSelfSource(t *testing.T) { h := unicastV4.header4Tuple(incoming) h.srcAddr = h.dstAddr - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -817,7 +809,7 @@ func TestV4ReadOnV4(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -955,7 +947,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... payload := buffer.View(newPayload()) n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { - c.t.Fatalf("Write failed: %v", err) + c.t.Fatalf("Write failed: %s", err) } if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) @@ -1005,7 +997,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } p := testDualWrite(c) @@ -1022,7 +1014,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV6) @@ -1043,7 +1035,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV4in6) @@ -1070,7 +1062,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Write to v6 address. @@ -1085,7 +1077,7 @@ func TestV6WriteOnConnected(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV6) @@ -1099,7 +1091,7 @@ func TestV4WriteOnConnected(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV4) @@ -1234,7 +1226,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testRead(c, unicastV4) @@ -1506,12 +1498,12 @@ func TestMulticastInterfaceOption(t *testing.T) { Port: stackPort, } if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } } if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOpt failed: %s", err) } // Verify multicast interface addr and NIC were set correctly. @@ -1519,7 +1511,7 @@ func TestMulticastInterfaceOption(t *testing.T) { ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) + c.t.Fatalf("GetSockOpt failed: %s", err) } if ifoptGot != ifoptWant { c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) @@ -1691,7 +1683,7 @@ func TestV6UnknownDestination(t *testing.T) { } // TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. +// global and endpoint stats are incremented. func TestIncrementMalformedPacketsReceived(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1699,20 +1691,25 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } payload := newPayload() - c.t.Helper() h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) + buf := c.buildV6Packet(payload, &h) + // Invalidate the packet length field in the UDP header by adding one. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetLength(u.Length() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - var want uint64 = 1 + const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) } if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -1728,7 +1725,6 @@ func TestShortHeader(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - c.t.Helper() h := unicastV6.header4Tuple(incoming) // Allocate a buffer for an IPv6 and too-short UDP header. @@ -1768,6 +1764,190 @@ func TestShortHeader(t *testing.T) { } } +// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Invalidate the checksum field in the UDP header by adding one. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Invalidate the checksum field in the UDP header by adding one. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV4 verifies if the checksum value is zero, global and +// endpoint states are *not* incremented (UDP checksum is optional on IPv4). +func TestChecksumZeroV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 0 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV6 verifies if the checksum value is zero, global and +// endpoint states are incremented (UDP checksum is *not* optional on IPv6). +func TestChecksumZeroV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { @@ -1778,15 +1958,15 @@ func TestShutdownRead(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingRead(c, unicastV6, true /* expectReadError */) @@ -1809,11 +1989,11 @@ func TestShutdownWrite(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) -- cgit v1.2.3 From b070e218c6fe61c6ef98e0a3af5ad58d7e627632 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 24 Jun 2020 10:21:44 -0700 Subject: Add support for Stack level options. Linux controls socket send/receive buffers using a few sysctl variables - net.core.rmem_default - net.core.rmem_max - net.core.wmem_max - net.core.wmem_default - net.ipv4.tcp_rmem - net.ipv4.tcp_wmem The first 4 control the default socket buffer sizes for all sockets raw/packet/tcp/udp and also the maximum permitted socket buffer that can be specified in setsockopt(SOL_SOCKET, SO_(RCV|SND)BUF,...). The last two control the TCP auto-tuning limits and override the default specified in rmem_default/wmem_default as well as the max limits. Netstack today only implements tcp_rmem/tcp_wmem and incorrectly uses it to limit the maximum size in setsockopt() as well as uses it for raw/udp sockets. This changelist introduces the other 4 and updates the udp/raw sockets to use the newly introduced variables. The values for min/max match the current tcp_rmem/wmem values and the default value buffers for UDP/RAW sockets is updated to match the linux value of 212KiB up from the really low current value of 32 KiB. Updates #3043 Fixes #3043 PiperOrigin-RevId: 318089805 --- benchmarks/tcp/tcp_proxy.go | 2 +- pkg/sentry/socket/netstack/stack.go | 12 +-- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/stack.go | 18 ++++ pkg/tcpip/stack/stack_options.go | 106 +++++++++++++++++++++ pkg/tcpip/stack/stack_test.go | 80 ++++++++++++++++ pkg/tcpip/tcpip.go | 27 +----- pkg/tcpip/transport/raw/endpoint.go | 20 ++-- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 ++-- pkg/tcpip/transport/tcp/endpoint_state.go | 10 +- pkg/tcpip/transport/tcp/protocol.go | 53 ++++++++--- pkg/tcpip/transport/tcp/tcp_sack_test.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 18 ++-- pkg/tcpip/transport/tcp/testing/context/context.go | 6 +- pkg/tcpip/transport/udp/endpoint.go | 20 ++-- pkg/tcpip/transport/udp/protocol.go | 49 +--------- runsc/boot/loader.go | 2 +- 18 files changed, 306 insertions(+), 140 deletions(-) create mode 100644 pkg/tcpip/stack/stack_options.go (limited to 'pkg/sentry/socket/netstack') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index b3a4dbea3..4b7ca7a14 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -228,7 +228,7 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(*sack)); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d2fb655ea..548442b96 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -143,7 +143,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcpip.StackReceiveBufferSizeOption + var rs tcp.ReceiveBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -154,7 +154,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcpip.StackReceiveBufferSizeOption{ + rs := tcp.ReceiveBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -164,7 +164,7 @@ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcpip.StackSendBufferSizeOption + var ss tcp.SendBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -175,7 +175,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcpip.StackSendBufferSizeOption{ + ss := tcp.SendBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -185,14 +185,14 @@ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcpip.StackSACKEnabled + var sack tcp.SACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enabled))).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } // Statistics implements inet.Stack.Statistics. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 24f52b735..794ddb5c8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -48,6 +48,7 @@ go_library( "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", ], visibility = ["//visibility:public"], diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e92ec0c24..cdcfb8321 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -471,6 +471,14 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. + receiveBufferSize ReceiveBufferSizeOption } // UniqueID is an abstract generator of unique identifiers. @@ -683,6 +691,16 @@ func New(opts Options) *Stack { tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 5aacbf53e..7657a4101 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3338,3 +3338,83 @@ func TestDoDADWhenNICEnabled(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) } } + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 956232a44..4d45dcc42 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -813,33 +813,8 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// StackSACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type StackSACKEnabled bool - -// StackDelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type StackDelayEnabled bool - -// StackSendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max send buffer sizes. -type StackSendBufferSizeOption struct { - Min int - Default int - Max int -} - -// StackReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// receive buffer sizes. -type StackReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // -// IPPacketInfo is the message struture for IP_PKTINFO. +// IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable type IPPacketInfo struct { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6a7977259..dd514d397 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -111,13 +111,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(transProto, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(transProto, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -541,9 +541,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) } if v > ss.Max { v = ss.Max @@ -559,9 +559,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) } if v > rs.Max { v = rs.Max diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 377643b82..9d4dce826 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -521,7 +521,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled tcpip.StackSACKEnabled + var sackEnabled SACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1e4c2f507..99a691815 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -847,12 +847,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -867,7 +867,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } - var de tcpip.StackDelayEnabled + var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1584,7 +1584,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1634,7 +1634,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1674,7 +1674,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -2595,7 +2595,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2676,7 +2676,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v tcpip.StackSACKEnabled + var v SACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 8258c0ecc..abf1ac5c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -182,13 +182,17 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 3cff55afa..f2ae6ce50 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -76,6 +76,31 @@ const ( ccCubic = "cubic" ) +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type DelayEnabled bool + +// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -137,8 +162,8 @@ type protocol struct { mu sync.RWMutex sackEnabled bool delayEnabled bool - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -249,19 +274,19 @@ func replyWithReset(s *segment, tos, ttl uint8) { // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case tcpip.StackSACKEnabled: + case SACKEnabled: p.mu.Lock() p.sackEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackDelayEnabled: + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackSendBufferSizeOption: + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -270,7 +295,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case tcpip.StackReceiveBufferSizeOption: + case ReceiveBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -363,25 +388,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *tcpip.StackSACKEnabled: + case *SACKEnabled: p.mu.RLock() - *v = tcpip.StackSACKEnabled(p.sackEnabled) + *v = SACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *tcpip.StackDelayEnabled: + case *DelayEnabled: p.mu.RLock() - *v = tcpip.StackDelayEnabled(p.delayEnabled) + *v = DelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *tcpip.StackSendBufferSizeOption: + case *SendBufferSizeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *tcpip.StackReceiveBufferSizeOption: + case *ReceiveBufferSizeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -491,12 +516,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{ + sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{ + recvBufferSize: ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 812e503bc..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2632a3c67..169adb16b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4005,7 +4005,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ Min: 1, Default: tcp.DefaultSendBufferSize * 2, Max: tcp.DefaultSendBufferSize * 20}); err != nil { @@ -4022,7 +4022,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ Min: 1, Default: tcp.DefaultReceiveBufferSize * 3, Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { @@ -4053,11 +4053,11 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5696,7 +5696,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5817,7 +5817,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5959,7 +5959,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcpip.StackDelayEnabled + delayEnabled tcp.DelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5974,10 +5974,10 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcpip.StackDelayEnabled + var gotDelayEnabled tcp.DelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9e262c272..06fde2a79 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -144,11 +144,11 @@ func New(t *testing.T, mtu uint32) *Context { }) // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -1091,7 +1091,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcpip.StackSACKEnabled + var v tcp.SACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6ea212093..8bdc1ee1f 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -190,13 +190,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -629,9 +629,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) } if v < rs.Min { @@ -648,9 +648,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) } if v < ss.Min { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index fc93f93c0..0e7464e3a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -21,7 +21,6 @@ package udp import ( - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,9 +49,6 @@ const ( ) type protocol struct { - mu sync.RWMutex - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption } // Number returns the udp protocol number. @@ -203,48 +199,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case tcpip.StackSendBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.sendBufferSize = v - p.mu.Unlock() - return nil - - case tcpip.StackReceiveBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.recvBufferSize = v - p.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *tcpip.StackSendBufferSizeOption: - p.mu.RLock() - *v = p.sendBufferSize - p.mu.RUnlock() - return nil - - case *tcpip.StackReceiveBufferSizeOption: - p.mu.RLock() - *v = p.recvBufferSize - p.mu.RUnlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Close implements stack.TransportProtocol.Close. @@ -267,8 +227,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize}, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize}, - } + return &protocol{} } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 081db39c1..b5df1deb9 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1058,7 +1058,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(true)); err != nil { + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { return nil, fmt.Errorf("failed to enable SACK: %s", err) } -- cgit v1.2.3 From 8dbeac53ce1b3c1cf4a5f2f0ccdd7196f4656fd8 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 26 Jun 2020 17:49:32 -0700 Subject: Implement SO_NO_CHECK socket option. SO_NO_CHECK is used to skip the UDP checksum generation on a TX socket (UDP checksum is optional on IPv4). Test: - TestNoChecksum - SoNoCheckOffByDefault (UdpSocketTest) - SoNoCheck (UdpSocketTest) Fixes #3055 PiperOrigin-RevId: 318575215 --- pkg/sentry/socket/netstack/netstack.go | 19 +++++ pkg/sentry/socket/socket.go | 1 - pkg/tcpip/checker/checker.go | 16 +++++ pkg/tcpip/tcpip.go | 100 +++++++++++++++------------ pkg/tcpip/transport/udp/endpoint.go | 24 +++++-- pkg/tcpip/transport/udp/udp_test.go | 24 +++++++ test/syscalls/linux/udp_socket_test_cases.cc | 38 ++++++++++ 7 files changed, 173 insertions(+), 49 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c0b63a803..e7d2c83d7 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1169,6 +1169,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return int32(v), nil + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + return boolToInt32(v), nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } @@ -1720,6 +1731,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 6580bd6e9..fcd7f9d7f 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -407,7 +407,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { linux.SO_MARK, linux.SO_MAX_PACING_RATE, linux.SO_NOFCS, - linux.SO_NO_CHECK, linux.SO_OOBINLINE, linux.SO_PASSCRED, linux.SO_PASSSEC, diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c1745ba6a..ee264b726 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -320,6 +320,22 @@ func DstPort(port uint16) TransportChecker { } } +// NoChecksum creates a checker that checks if the checksum is zero. +func NoChecksum(noChecksum bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + udp, ok := h.(header.UDP) + if !ok { + return + } + + if b := udp.Checksum() == 0; b != noChecksum { + t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) + } + } +} + // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4d45dcc42..2be1c107a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -585,59 +585,68 @@ type WriteOptions struct { type SockOptBool int const ( - // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether - // datagram sockets are allowed to send packets to a broadcast address. + // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify + // whether datagram sockets are allowed to send packets to a broadcast + // address. BroadcastOption SockOptBool = iota - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be - // held until segments are full by the TCP transport protocol. + // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be held until segments are full by the TCP transport + // protocol. CorkOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. + // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be sent out immediately by the transport protocol. For + // TCP, it determines if the Nagle algorithm is on or off. DelayOption - // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether - // TCP keepalive is enabled for this socket. + // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to + // specify whether TCP keepalive is enabled for this socket. KeepaliveEnabledOption - // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether - // multicast packets sent over a non-loopback interface will be looped back. + // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to + // specify whether multicast packets sent over a non-loopback interface + // will be looped back. MulticastLoopOption - // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether - // SCM_CREDENTIALS socket control messages are enabled. + // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify + // whether UDP checksum is disabled for this socket. + NoChecksumOption + + // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify + // whether SCM_CREDENTIALS socket control messages are enabled. // // Only supported on Unix sockets. PasscredOption - // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption - // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the - // IPV6_TCLASS ancillary message is passed with incoming packets. + // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to + // specify if the IPV6_TCLASS ancillary message is passed with incoming + // packets. ReceiveTClassOption - // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS - // ancillary message is passed with incoming packets. + // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify + // if the TOS ancillary message is passed with incoming packets. ReceiveTOSOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify - // if more inforamtion is provided with incoming packets such - // as interface index and address. + // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to + // specify if more inforamtion is provided with incoming packets such as + // interface index and address. ReceiveIPPacketInfoOption - // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() - // should allow reuse of local address. + // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to + // specify whether Bind() should allow reuse of local address. ReuseAddressOption - // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets - // to be bound to an identical socket address. + // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit + // multiple sockets to be bound to an identical socket address. ReusePortOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. + // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify + // whether an IPv6 socket is to be restricted to sending and receiving + // IPv6 packets only. V6OnlyOption ) @@ -645,25 +654,27 @@ const ( type SockOptInt int const ( - // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number - // of un-ACKed TCP keepalives that will be sent before the connection is - // closed. + // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to + // specify the number of un-ACKed TCP keepalives that will be sent + // before the connection is closed. KeepaliveCountOption SockOptInt = iota - // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS // for all subsequent outgoing IPv4 packets from the endpoint. IPv4TOSOption - // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS - // for all subsequent outgoing IPv6 packets from the endpoint. + // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to + // specify TOS for all subsequent outgoing IPv6 packets from the + // endpoint. IPv6TrafficClassOption - // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current - // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the + // current Maximum Segment Size(MSS) value as specified using the + // TCP_MAXSEG option. MaxSegOption - // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default - // TTL value for multicast messages. The default is 1. + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control + // the default TTL value for multicast messages. The default is 1. MulticastTTLOption // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the @@ -682,21 +693,22 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop - // limit value for unicast messages. The default is protocol specific. + // TTLOption is used by SetSockOptInt/GetSockOptInt to control the + // default TTL/hop limit value for unicast messages. The default is + // protocol specific. // // A zero value indicates the default. TTLOption - // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of - // SYN retransmits that TCP should send before aborting the attempt to - // connect. It cannot exceed 255. + // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify + // the number of SYN retransmits that TCP should send before aborting + // the attempt to connect. It cannot exceed 255. // // NOTE: This option is currently only stubbed out and is no-op. TCPSynCountOption - // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size - // of the advertised window to this value. + // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound + // the size of the advertised window to this value. // // NOTE: This option is currently only stubed out and is a no-op TCPWindowClampOption diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 8bdc1ee1f..cae29fbff 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,6 +109,7 @@ type endpoint struct { portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool + noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -529,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -553,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -825,6 +831,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -959,7 +971,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -973,8 +985,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Length: length, }) - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index ff9f60cf9..db59eb5a0 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1251,6 +1251,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index cc1db3de8..1d13432ca 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1239,6 +1239,44 @@ TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { EXPECT_EQ(n, 0); } +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + TEST_P(UdpSocketTest, SoTimestampOffByDefault) { // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by // hostinet. -- cgit v1.2.3 From 5946f111827fa4e342a2e6e9c043c198d2e5cb03 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 9 Jul 2020 16:24:43 -0700 Subject: Add support for IP_HDRINCL IP option for raw sockets. Updates #2746 Fixes #3158 PiperOrigin-RevId: 320497190 --- pkg/sentry/socket/netstack/netstack.go | 11 ++++- pkg/tcpip/tcpip.go | 5 +++ pkg/tcpip/transport/raw/endpoint.go | 34 +++++++++++---- test/syscalls/linux/raw_socket_hdrincl.cc | 70 ++++++++++++++++++++++++++++++- 4 files changed, 110 insertions(+), 10 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e7d2c83d7..3b248a953 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2112,13 +2112,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2f9872dc6..25534a10d 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -648,6 +648,11 @@ const ( // whether an IPv6 socket is to be restricted to sending and receiving // IPv6 packets only. V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 766c7648e..5b6e7d102 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if !e.associated { + if e.hdrIncluded { if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -513,6 +511,13 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -577,6 +582,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -616,8 +627,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 16cfc1d75..5bb14d57c 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) { // nothing to be read. char buf[117]; ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EINVAL)); + SyscallFailsWithErrno(EAGAIN)); } // Test that we can connect() to a valid IP (loopback). @@ -332,6 +332,74 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); } +// Send and receive a packet w/ the IP_HDRINCL option set. +TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) { + int port = 40000; + if (!IsRunningOnGvisor()) { + port = static_cast(ASSERT_NO_ERRNO_AND_VALUE( + PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); + } + + FileDescriptor recv_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + FileDescriptor send_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + // Enable IP_HDRINCL option so that we can build and send w/ an IP + // header. + constexpr int kSockOptOn = 1; + ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // This is not strictly required but we do it to make sure that setting + // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving + // packets. + ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; + char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; + ASSERT_TRUE( + FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + + socklen_t addrlen = sizeof(addr_); + ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0, + reinterpret_cast(&addr_), + addrlen)); + + // Receive the payload. + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0, + reinterpret_cast(&src), &src_size), + SyscallSucceedsWithValue(sizeof(packet))); + EXPECT_EQ( + memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), + sizeof(kPayload)), + 0); + // The network stack should have set the source address. + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); + struct iphdr iphdr = {}; + memcpy(&iphdr, recv_buf, sizeof(iphdr)); + EXPECT_NE(iphdr.id, 0); + + // Also verify that the packet we just sent was not delivered to the + // IPPROTO_RAW socket. + { + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast(&src), &src_size), + SyscallFailsWithErrno(EAGAIN)); + } +} + } // namespace } // namespace testing -- cgit v1.2.3 From 5df3a8fedef7e54550d4c6b7172e25216600ee9f Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 9 Jul 2020 22:33:53 -0700 Subject: Discard multicast UDP source address. RFC-1122 (and others) specify that UDP should not receive datagrams that have a source address that is a multicast address. Packets should never be received FROM a multicast address. See also, RFC 768: 'User Datagram Protocol' J. Postel, ISI, 28 August 1980 A UDP datagram received with an invalid IP source address (e.g., a broadcast or multicast address) must be discarded by UDP or by the IP layer (see rfc 1122 Section 3.2.1.3). This CL does not address TCP or broadcast which is more complicated. Also adds a test for both ipv6 and ipv4 UDP. Fixes #3154 PiperOrigin-RevId: 320547674 --- pkg/sentry/socket/netstack/netstack.go | 1 + pkg/tcpip/tcpip.go | 5 +- pkg/tcpip/transport/udp/endpoint.go | 11 +++- pkg/tcpip/transport/udp/udp_test.go | 104 ++++++++++++++++++++++++++++----- 4 files changed, 103 insertions(+), 18 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3b248a953..5a3cedd7c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -192,6 +192,7 @@ var Metrics = tcpip.Stats{ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 25534a10d..cf7291d09 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -782,7 +782,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -1244,6 +1244,9 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0584ec8dc..4e9e114a9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1377,6 +1377,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + // Verify checksum unless RX checksum offload is enabled. // On IPv4, UDP checksum is optional, and a zero value means // the transmitter omitted the checksum generation (RFC768). @@ -1395,10 +1404,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 91ba031fa..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { -- cgit v1.2.3 From 216dcebc066c82907b0de790a77a3deb6a734805 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Sat, 11 Jul 2020 06:21:34 -0700 Subject: Stub out SO_DETACH_FILTER. Updates #2746 PiperOrigin-RevId: 320757963 --- pkg/sentry/socket/netstack/netstack.go | 5 +++ pkg/tcpip/tcpip.go | 5 ++- pkg/tcpip/transport/icmp/endpoint.go | 4 ++ pkg/tcpip/transport/packet/endpoint.go | 8 +++- pkg/tcpip/transport/raw/endpoint.go | 8 +++- pkg/tcpip/transport/tcp/endpoint.go | 3 ++ pkg/tcpip/transport/udp/endpoint.go | 3 ++ test/syscalls/linux/BUILD | 3 ++ test/syscalls/linux/packet_socket_raw.cc | 34 ++++++++++++++++ test/syscalls/linux/raw_socket.cc | 37 +++++++++++++++-- test/syscalls/linux/tcp_socket.cc | 60 ++++++++++++++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 54 +++++++++++++++++++++++++ 12 files changed, 217 insertions(+), 7 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5a3cedd7c..78a842973 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1754,6 +1754,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index cf7291d09..71bcee785 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -855,7 +855,10 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 62d1acad4..678f4e016 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index a8f8454dd..57b7f5c19 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -278,7 +278,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5b6e7d102..c2e9fd29f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -506,7 +506,13 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index caac6ef57..83dc10ed0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1792,6 +1792,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4e9e114a9..a14643ae8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -816,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 9e097c888..662d780d8 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1330,6 +1330,7 @@ cc_binary( name = "packet_socket_raw_test", testonly = 1, srcs = ["packet_socket_raw.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -1809,6 +1810,7 @@ cc_binary( name = "raw_socket_test", testonly = 1, srcs = ["raw_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -3407,6 +3409,7 @@ cc_binary( name = "tcp_socket_test", testonly = 1, srcs = ["tcp_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 4093ac813..6a963b12c 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,6 +14,9 @@ #include #include +#ifndef __fuchsia__ +#include +#endif // __fuchsia__ #include #include #include @@ -556,6 +559,37 @@ TEST_P(RawPacketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } +#ifndef __fuchsia__ + +TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + // + // gVisor returns no error on SO_DETACH_FILTER even if there is no filter + // attached unlike linux which does return ENOENT in such cases. This is + // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns + // success. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawPacketTest, GetSocketDetachFilter) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 05c4ed03f..ce54dc064 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -13,6 +13,9 @@ // limitations under the License. #include +#ifndef __fuchsia__ +#include +#endif // __fuchsia__ #include #include #include @@ -21,6 +24,7 @@ #include #include #include + #include #include "gtest/gtest.h" @@ -790,10 +794,30 @@ void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); } -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Combine( - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP), - ::testing::Values(AF_INET, AF_INET6))); +#ifndef __fuchsia__ + +TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ // AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. TEST(RawSocketTest, IPv6ProtoRaw) { @@ -813,6 +837,11 @@ TEST(RawSocketTest, IPv6ProtoRaw) { SyscallFailsWithErrno(EINVAL)); } +INSTANTIATE_TEST_SUITE_P( + AllInetTests, RawSocketTest, + ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), + ::testing::Values(AF_INET, AF_INET6))); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index a4d2953e1..0cea7d11f 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,6 +13,9 @@ // limitations under the License. #include +#ifndef __fuchsia__ +#include +#endif // __fuchsia__ #include #include #include @@ -1559,6 +1562,63 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { } } +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 9cc6be4fb..60c48ed6e 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -16,6 +16,9 @@ #include #include +#ifndef __fuchsia__ +#include +#endif // __fuchsia__ #include #include #include @@ -1723,5 +1726,56 @@ TEST_P(UdpSocketTest, RecvBufLimits) { } } +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From fef90c61c6186c113cfdb0bbcf53f4ca70f9741a Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 15 Jul 2020 14:13:42 -0700 Subject: Fix minor bugs in a couple of interface IOCTLs. gVisor incorrectly returns the wrong ARP type for SIOGIFHWADDR. This breaks tcpdump as it tries to interpret the packets incorrectly. Similarly, SIOCETHTOOL is used by tcpdump to query interface properties which fails with an EINVAL since we don't implement it. For now change it to return EOPNOTSUPP to indicate that we don't support the query rather than return EINVAL. NOTE: ARPHRD types for link endpoints are distinct from NIC capabilities and NIC flags. In Linux all 3 exist eg. ARPHRD types are stored in dev->type field while NIC capabilities are more like the device features which can be queried using SIOCETHTOOL but not modified and NIC Flags are fields that can be modified from user space. eg. NIC status (UP/DOWN/MULTICAST/BROADCAST) etc. Updates #2746 PiperOrigin-RevId: 321436525 --- pkg/abi/linux/ioctl.go | 27 +++++++++-- pkg/abi/linux/netlink_route.go | 2 + pkg/sentry/socket/netstack/netstack.go | 74 ++++++++++++++++++------------ pkg/sentry/socket/netstack/stack.go | 22 ++++++--- pkg/tcpip/header/arp.go | 77 +++++++++++++++++++++----------- pkg/tcpip/link/channel/BUILD | 1 + pkg/tcpip/link/channel/channel.go | 6 +++ pkg/tcpip/link/fdbased/endpoint.go | 8 ++++ pkg/tcpip/link/loopback/loopback.go | 5 +++ pkg/tcpip/link/muxed/BUILD | 1 + pkg/tcpip/link/muxed/injectable.go | 6 +++ pkg/tcpip/link/nested/BUILD | 1 + pkg/tcpip/link/nested/nested.go | 6 +++ pkg/tcpip/link/qdisc/fifo/BUILD | 1 + pkg/tcpip/link/qdisc/fifo/endpoint.go | 6 +++ pkg/tcpip/link/sharedmem/sharedmem.go | 5 +++ pkg/tcpip/link/tun/device.go | 10 +++++ pkg/tcpip/link/waitable/BUILD | 2 + pkg/tcpip/link/waitable/waitable.go | 6 +++ pkg/tcpip/link/waitable/waitable_test.go | 6 +++ pkg/tcpip/network/ip_test.go | 9 +++- pkg/tcpip/stack/forwarder_test.go | 6 +++ pkg/tcpip/stack/nic_test.go | 5 +++ pkg/tcpip/stack/registration.go | 7 +++ pkg/tcpip/stack/stack.go | 6 +++ test/syscalls/linux/socket_netdevice.cc | 23 ++++++++++ 26 files changed, 263 insertions(+), 65 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..2c5e56ae5 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 78a842973..0b1be1bd2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2747,7 +2747,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2788,18 +2788,19 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ @@ -2815,7 +2816,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc }) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf @@ -2889,7 +2890,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2912,21 +2913,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2935,7 +2943,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2946,32 +2954,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2988,6 +2996,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 548442b96..67737ae87 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,6 +15,8 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..a2bb773d4 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -296,3 +297,8 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..32abe2a13 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -626,6 +626,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..3b17d8c28 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -113,3 +113,8 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..c305d9e86 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,11 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index bdd5276ad..2cdb23475 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 2998f9c4f..328bd048e 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,3 +130,8 @@ func (e *Endpoint) GSOMaxSize() uint32 { } return 0 } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..c84fe1bb9 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -207,3 +208,8 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..a36862c67 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -287,3 +287,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..47446efec 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -348,6 +349,7 @@ type tunEndpoint struct { stack *stack.Stack nicID tcpip.NICID name string + isTap bool } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. @@ -356,3 +358,11 @@ func (e *tunEndpoint) DecRef() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..24a8dc2eb 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -147,3 +148,8 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..ffb2354be 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,6 +82,11 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..a5b780ca2 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,14 +172,19 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..eefb4b07f 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -301,6 +302,11 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..3bc9fd831 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -84,6 +84,11 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..f260eeb7f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -436,6 +437,12 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0aa815447..2b7ece851 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1095,6 +1095,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1126,6 +1131,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 15d4b85a7..5f8d7f981 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) { // Check that the loopback is zero hardware address. ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); @@ -178,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) { EXPECT_GT(ifr.ifr_mtu, 0); } +TEST(NetdeviceTest, EthtoolGetTSInfo) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ethtool_ts_info tsi = {}; + tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities. + + // Prepare the request. + struct ifreq ifr = {}; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + ifr.ifr_data = (void*)&tsi; + + // Check that SIOCGIFMTU returns a nonzero MTU. + if (IsRunningOnGvisor()) { + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), + SyscallFailsWithErrno(EOPNOTSUPP)); + return; + } + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds()); +} + } // namespace } // namespace testing -- cgit v1.2.3 From dcf6ddc2772b8fcf824f1f48e0281e1cc80b93ea Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 16 Jul 2020 18:38:28 -0700 Subject: Add support to return protocol in recvmsg for AF_PACKET. Updates #173 PiperOrigin-RevId: 321690756 --- pkg/sentry/socket/netstack/netstack.go | 24 +++++++++++++++++++++--- pkg/tcpip/tcpip.go | 19 +++++++++++++++++++ pkg/tcpip/transport/packet/endpoint.go | 18 ++++++++++++++++-- test/syscalls/linux/packet_socket.cc | 1 + test/syscalls/linux/packet_socket_raw.cc | 1 + 5 files changed, 58 insertions(+), 5 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b1be1bd2..49a04e613 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -297,8 +297,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -447,8 +448,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -2509,6 +2523,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + } } if peek { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 71bcee785..48ad56d4d 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -549,6 +549,25 @@ type Endpoint interface { SetOwner(owner PacketOwner) } +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) +} + // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 92b487381..7b2083a09 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -45,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -151,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -177,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(b/129292371): Implement. return 0, nil, tcpip.ErrInvalidOptionValue @@ -428,12 +440,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto } if ep.cooked { diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index ce63adb23..e94ddcb77 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -193,6 +193,7 @@ void ReceiveMessage(int sock, int ifindex) { EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index e44062475..2fca9fe4d 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -200,6 +200,7 @@ TEST_P(RawPacketTest, Receive) { EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); -- cgit v1.2.3 From 71bf90c55bd888f9b9c493533ca5e4b2b4b3d21d Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 22 Jul 2020 15:12:56 -0700 Subject: Support for receiving outbound packets in AF_PACKET. Updates #173 PiperOrigin-RevId: 322665518 --- pkg/abi/linux/socket.go | 9 +++ pkg/sentry/socket/netstack/netstack.go | 19 +++++ pkg/tcpip/link/channel/channel.go | 4 + pkg/tcpip/link/fdbased/endpoint.go | 36 ++++----- pkg/tcpip/link/fdbased/endpoint_test.go | 8 ++ pkg/tcpip/link/loopback/loopback.go | 3 + pkg/tcpip/link/muxed/injectable.go | 4 + pkg/tcpip/link/nested/nested.go | 15 ++++ pkg/tcpip/link/nested/nested_test.go | 4 + pkg/tcpip/link/packetsocket/BUILD | 14 ++++ pkg/tcpip/link/packetsocket/endpoint.go | 50 +++++++++++++ pkg/tcpip/link/qdisc/fifo/endpoint.go | 12 +++ pkg/tcpip/link/sharedmem/sharedmem.go | 21 ++++-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 4 + pkg/tcpip/link/sniffer/sniffer.go | 5 ++ pkg/tcpip/link/tun/device.go | 43 +++++++---- pkg/tcpip/link/waitable/waitable.go | 14 ++++ pkg/tcpip/link/waitable/waitable_test.go | 9 +++ pkg/tcpip/network/ip_test.go | 5 ++ pkg/tcpip/stack/forwarder_test.go | 5 ++ pkg/tcpip/stack/nic.go | 30 ++++++-- pkg/tcpip/stack/nic_test.go | 5 ++ pkg/tcpip/stack/packet_buffer.go | 4 + pkg/tcpip/stack/registration.go | 16 +++- pkg/tcpip/tcpip.go | 25 +++++++ pkg/tcpip/transport/packet/endpoint.go | 52 +++++++++---- runsc/boot/BUILD | 1 + runsc/boot/network.go | 4 + test/syscalls/linux/packet_socket.cc | 116 +++++++++++++++++++++++++++++ 29 files changed, 471 insertions(+), 66 deletions(-) create mode 100644 pkg/tcpip/link/packetsocket/BUILD create mode 100644 pkg/tcpip/link/packetsocket/endpoint.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..95337c168 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -134,6 +134,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 49a04e613..964ec8414 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -2468,6 +2469,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2526,6 +2544,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq switch v := addr.(type) { case *linux.SockAddrLink: v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a2bb773d4..e12a5929b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { func (*Endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 6aa1badc7..c18bb91fb 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,26 +386,33 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } var builder iovec.Builder @@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var eth header.Ethernet if e.hdrSize > 0 { - // Add ethernet header if needed. - eth = make(header.Ethernet, header.EthernetMinimumSize) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } var vnetHdrBuf []byte @@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(eth) builder.Add(pkt.Header.View()) for _, v := range pkt.Data.Views() { builder.Add(v) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 4bad930c7..7b995b85a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd d.pkts = append(d.pkts, pkt) } +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestDispatchPacketFormat(t *testing.T) { for _, test := range []struct { name string diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 3b17d8c28..781cdd317 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback } + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c305d9e86..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unsupported operation") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 328bd048e..d40de54df 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco } } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.mu.Lock() @@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 { func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index c1a219f02..7d9249c1c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd d.count++ } +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNestedLinkEndpoint(t *testing.T) { const emptyAddress = tcpip.LinkAddress("") diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index c84fe1bb9..467083239 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -213,3 +220,8 @@ func (e *endpoint) Wait() { func (e *endpoint) ARPHardwareType() header.ARPHardwareType { return e.lower.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index a36862c67..507c76b76 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) v := pkt.Data.ToView() // Transmit the packet. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..8f3cd9449 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d9cd4e83a..509076643 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) +} + func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 47446efec..04ae58e59 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader) } // Append upper headers. @@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { } return header.ARPHardwareNone } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 24a8dc2eb..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {} func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.lower.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index ffb2354be..c448a888f 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a5b780ca2..615bae648 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index eefb4b07f..bca1d940b 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -307,6 +307,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7b80534e6..fea0ce7e8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 3bc9fd831..c477e31d8 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e3556d5d2..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -79,6 +79,10 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index f260eeb7f..cd4b7a449 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -330,8 +330,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -342,6 +341,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -443,6 +452,9 @@ type LinkEndpoint interface { // See: // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 48ad56d4d..ff14a3b3c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -316,6 +316,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -555,6 +577,9 @@ type Endpoint interface { type LinkPacketInfo struct { // Protocol is the NetworkProtocolNumber for the packet. Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType } // PacketEndpoint are additional methods that are only implemented by Packet diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 7b2083a09..8f167391f 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, Addr: tcpip.Address(hdr.SourceAddress()), } packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ @@ -448,30 +449,53 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, Addr: tcpip.Address(localAddr), } packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. + combinedVV.AppendView(pkt.Header.View()) } - combinedVV := linkHeader.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 55d45aaa6..9f52438c2 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -90,6 +90,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 14d2f56a5..4e1fa7665 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket" "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -252,6 +253,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) } + // Enable support for AF_PACKET sockets to receive outgoing packets. + linkEP = packetsocket.New(linkEP) + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index e94ddcb77..40aa9326d 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -417,6 +417,122 @@ TEST_P(CookedPacketTest, BindDrop) { EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); } +// Verify that we receive outbound packets. This test requires at least one +// non loopback interface so that we can actually capture an outgoing packet. +TEST_P(CookedPacketTest, ReceiveOutbound) { + // Only ETH_P_ALL sockets can receive outbound packets on linux. + SKIP_IF(GetParam() != ETH_P_ALL); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index and name. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + int ifindex = ifr.ifr_ifindex; + + constexpr int kMACSize = 6; + char hwaddr[kMACSize]; + // Get interface address. + ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + ASSERT_THAT(ifr.ifr_hwaddr.sa_family, + AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER))); + memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize); + + // Just send it to the google dns server 8.8.8.8. It's UDP we don't care + // if it actually gets to the DNS Server we just want to see that we receive + // it on our AF_PACKET socket. + // + // NOTE: We just want to pick an IP that is non-local to avoid having to + // handle ARP as this should cause the UDP packet to be sent to the default + // gateway configured for the system under test. Otherwise the only packet we + // will see is the ARP query unless we picked an IP which will actually + // resolve. The test is a bit brittle but this was the best compromise for + // now. + struct sockaddr_in dest = {}; + ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket receives the data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1)); + + // Now read and check that the packet is the one we just sent. + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, ifindex); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING); + // Verify the link address of the interface matches that of the non + // non loopback interface address we stored above. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], hwaddr[i]); + } + + // Verify the IP header. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr); + EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + // Bind with invalid address. TEST_P(CookedPacketTest, BindFail) { // Null address. -- cgit v1.2.3 From 6f7f73996791bbab6b63c248000df0d3ce652f2b Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Thu, 23 Jul 2020 11:43:40 -0700 Subject: Marshallable socket opitons. Socket option values are now required to implement marshal.Marshallable. Co-authored-by: Rahat Mahmood PiperOrigin-RevId: 322831612 --- pkg/abi/linux/BUILD | 3 + pkg/abi/linux/netfilter.go | 146 ++++++++++++++++++-- pkg/abi/linux/socket.go | 8 ++ pkg/sentry/socket/BUILD | 1 + pkg/sentry/socket/hostinet/BUILD | 2 + pkg/sentry/socket/hostinet/socket.go | 7 +- pkg/sentry/socket/netfilter/netfilter.go | 28 ++-- pkg/sentry/socket/netlink/BUILD | 2 + pkg/sentry/socket/netlink/socket.go | 14 +- pkg/sentry/socket/netstack/BUILD | 2 + pkg/sentry/socket/netstack/netstack.go | 205 ++++++++++++++++++---------- pkg/sentry/socket/netstack/netstack_vfs2.go | 16 ++- pkg/sentry/socket/socket.go | 3 +- pkg/sentry/socket/unix/BUILD | 1 + pkg/sentry/socket/unix/unix.go | 3 +- pkg/sentry/socket/unix/unix_vfs2.go | 3 +- pkg/sentry/syscalls/linux/BUILD | 2 + pkg/sentry/syscalls/linux/sys_socket.go | 17 ++- pkg/sentry/syscalls/linux/vfs2/BUILD | 2 + pkg/sentry/syscalls/linux/vfs2/socket.go | 17 ++- tools/go_marshal/README.md | 8 +- tools/go_marshal/primitive/primitive.go | 72 ++++++++++ 22 files changed, 427 insertions(+), 135 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 2b789c4ec..a4bb62013 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -72,6 +72,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 95337c168..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -234,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -303,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -317,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -414,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a92aed2c9..ec5506efc 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 1243143ea..d9394055d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) info.NumEntries++ } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 964ec8414..9856ab8c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -62,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -910,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -920,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -956,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -971,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -981,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -1014,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1025,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1035,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1050,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1066,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1082,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1093,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1104,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1112,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1127,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1138,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1149,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1163,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1171,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1183,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1194,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1203,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1214,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1225,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1236,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1247,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1259,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1271,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1283,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1295,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1309,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1344,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1356,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1368,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1379,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1391,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1400,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1411,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1419,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1444,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1453,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1466,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1482,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1496,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1507,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1532,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1543,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fcd7f9d7f..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -35,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0c740335b..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -72,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 4886efddf..68d759083 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -9,11 +9,9 @@ automatically generating code to marshal go data structures to memory. `binary.Marshal` by moving the go runtime reflection necessary to marshal a struct to compile-time. -`go_marshal` automatically generates implementations for `abi.Marshallable` and -`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall -implementations) can directly invoke `safemem.Reader.ReadToBlocks` and -`safemem.Writer.WriteFromBlocks`. Data structures that require custom -serialization will have manual implementations for these interfaces. +`go_marshal` automatically generates implementations for `marshal.Marshallable` +and `safemem.{Reader,Writer}`. Data structures that require custom serialization +will have manual implementations for these interfaces. Data structures can be flagged for code generation by adding a struct-level comment `// +marshal`. diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go index ebcf130ae..d93edda8b 100644 --- a/tools/go_marshal/primitive/primitive.go +++ b/tools/go_marshal/primitive/primitive.go @@ -17,10 +17,22 @@ package primitive import ( + "io" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/marshal" ) +// Int8 is a marshal.Marshallable implementation for int8. +// +// +marshal slice:Int8Slice:inner +type Int8 int8 + +// Uint8 is a marshal.Marshallable implementation for uint8. +// +// +marshal slice:Uint8Slice:inner +type Uint8 uint8 + // Int16 is a marshal.Marshallable implementation for int16. // // +marshal slice:Int16Slice:inner @@ -51,6 +63,66 @@ type Int64 int64 // +marshal slice:Uint64Slice:inner type Uint64 uint64 +// ByteSlice is a marshal.Marshallable implementation for []byte. +// This is a convenience wrapper around a dynamically sized type, and can't be +// embedded in other marshallable types because it breaks assumptions made by +// go-marshal internals. It violates the "no dynamically-sized types" +// constraint of the go-marshal library. +type ByteSlice []byte + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *ByteSlice) SizeBytes() int { + return len(*b) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *ByteSlice) MarshalBytes(dst []byte) { + copy(dst, *b) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *ByteSlice) UnmarshalBytes(src []byte) { + copy(*b, src) +} + +// Packed implements marshal.Marshallable.Packed. +func (b *ByteSlice) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *ByteSlice) MarshalUnsafe(dst []byte) { + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *ByteSlice) UnmarshalUnsafe(src []byte) { + b.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyInBytes(addr, *b) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyOutBytes(addr, *b) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + return task.CopyOutBytes(addr, (*b)[:limit]) +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(*b) + return int64(n), err +} + +var _ marshal.Marshallable = (*ByteSlice)(nil) + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. -- cgit v1.2.3 From e2c70ee9814f0f76ab5c30478748e4c697e91f33 Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Fri, 24 Jul 2020 01:24:16 -0700 Subject: Enable automated marshalling for netstack. PiperOrigin-RevId: 322954792 --- pkg/abi/linux/netdevice.go | 4 +++ pkg/sentry/fs/file_operations.go | 1 + pkg/sentry/socket/netfilter/netfilter.go | 4 +-- pkg/sentry/socket/netstack/netstack.go | 54 ++++++++++++++------------------ 4 files changed, 31 insertions(+), 32 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 7866352b4..0faf015c7 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -22,6 +22,8 @@ const ( ) // IFReq is an interface request. +// +// +marshal type IFReq struct { // IFName is an encoded name, normally null-terminated. This should be // accessed via the Name and SetName functions. @@ -79,6 +81,8 @@ type IFMap struct { // IFConf is used to return a list of interfaces and their addresses. See // netdevice(7) and struct ifconf for more detail on its use. +// +// +marshal type IFConf struct { Len int32 _ [4]byte // Pad to sizeof(struct ifconf). diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index beba0f771..f5537411e 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -160,6 +160,7 @@ type FileOperations interface { // refer. // // Preconditions: The AddressSpace (if any) that io refers to is activated. + // Must only be called from a task goroutine. Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index d9394055d..a9f0604ae 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -66,7 +66,7 @@ func nflog(format string, args ...interface{}) { func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } @@ -84,7 +84,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9856ab8c5..44b3fff46 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2835,6 +2835,11 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. @@ -2847,9 +2852,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2868,9 +2871,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2879,6 +2881,11 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { case linux.SIOCGIFFLAGS, linux.SIOCGIFADDR, @@ -2895,37 +2902,28 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2938,9 +2936,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2954,9 +2951,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -3105,7 +3101,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -3142,9 +3138,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } -- cgit v1.2.3 From f82dd8ddb477b8923d1db12654a62d55d613019c Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Tue, 28 Jul 2020 21:22:52 -0700 Subject: Redirect TODO to GitHub issues PiperOrigin-RevId: 323715260 --- pkg/sentry/socket/netstack/netstack.go | 4 ++-- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack_test.go | 6 +++--- pkg/tcpip/transport/packet/endpoint.go | 4 ++-- test/syscalls/linux/packet_socket.cc | 4 ++-- test/syscalls/linux/packet_socket_raw.cc | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 44b3fff46..f86e6cd7a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -423,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -2418,7 +2418,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index fea0ce7e8..9256d4d43 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -181,7 +181,7 @@ func (n *NIC) disableLocked() *tcpip.Error { return nil } - // TODO(b/147015577): Should Routes that are currently bound to n be + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be // invalidated? Currently, Routes will continue to work when a NIC is enabled // again, and applications may not know that the underlying NIC was ever // disabled. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7657a4101..101ca2206 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -867,9 +867,9 @@ func TestRouteWithDownNIC(t *testing.T) { // Writes with Routes that use NIC1 after being brought up should // succeed. // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? if err := upFn(s, nicID1); err != nil { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0e46e6355..df478115d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -193,7 +193,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } @@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. + // TODO(gvisor.dev/issue/173): Return network protocol. if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader) diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 40aa9326d..861617ff7 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -188,7 +188,7 @@ void ReceiveMessage(int sock, int ifindex) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, ifindex); @@ -234,7 +234,7 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 2fca9fe4d..a11a03415 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -195,7 +195,7 @@ TEST_P(RawPacketTest, Receive) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); @@ -240,7 +240,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. -- cgit v1.2.3 From 2a7b2a61e3ea32129c26eeaa6fab3d81a5d8ad6e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 11 Jun 2020 20:33:56 -0700 Subject: iptables: support SO_ORIGINAL_DST Envoy (#170) uses this to get the original destination of redirected packets. --- pkg/abi/linux/netfilter.go | 8 +- pkg/abi/linux/socket.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 17 ++++ pkg/sentry/strace/socket.go | 1 + pkg/tcpip/stack/conntrack.go | 26 ++++++ pkg/tcpip/stack/iptables.go | 11 ++- pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/tcp/endpoint.go | 11 +++ test/iptables/BUILD | 1 + test/iptables/iptables_test.go | 8 ++ test/iptables/iptables_unsafe.go | 63 ++++++++++++++ test/iptables/iptables_util.go | 51 ++++++++++- test/iptables/nat.go | 152 ++++++++++++++++++++++++++++++++- 13 files changed, 344 insertions(+), 13 deletions(-) create mode 100644 test/iptables/iptables_unsafe.go (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index a91f9f018..9c27f7bb2 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -59,7 +59,7 @@ var VerdictStrings = map[int32]string{ NF_RETURN: "RETURN", } -// Socket options. These correspond to values in +// Socket options for SOL_SOCKET. These correspond to values in // include/uapi/linux/netfilter_ipv4/ip_tables.h. const ( IPT_BASE_CTL = 64 @@ -74,6 +74,12 @@ const ( IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET ) +// Socket option for SOL_IP. This corresponds to the value in +// include/uapi/linux/netfilter_ipv4.h. +const ( + SO_ORIGINAL_DST = 80 +) + // Name lengths. These correspond to values in // include/uapi/linux/netfilter/x_tables.h. const ( diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index c24a8216e..d6946bb82 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -239,11 +239,13 @@ const SockAddrMax = 128 type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. +// +// +marshal type SockAddrInet struct { Family uint16 Port uint16 Addr InetAddr - Zero [8]uint8 // pad to sizeof(struct sockaddr). + _ [8]uint8 // pad to sizeof(struct sockaddr). } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f86e6cd7a..31a168f7e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1490,6 +1490,10 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument + default: emitUnimplementedEventIPv6(t, name) } @@ -1600,6 +1604,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil + default: emitUnimplementedEventIP(t, name) } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index c0512de89..b51c4c941 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -521,6 +521,7 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT", linux.IP_PKTOPTIONS: "IP_PKTOPTIONS", linux.IP_MTU: "IP_MTU", + linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST", }, linux.SOL_SOCKET: { linux.SO_ERROR: "SO_ERROR", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 559a1c4dd..470c265aa 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -240,7 +240,10 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { if err != nil { return nil, dirOriginal } + return ct.connForTID(tid) +} +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { bucket := ct.bucket(tid) now := time.Now() @@ -604,3 +607,26 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } + +func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: header.IPv4ProtocolNumber, + } + conn, _ := ct.connForTID(tid) + if conn == nil { + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue + } + + return conn.original.dstAddr, conn.original.dstPort, nil +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index cbbae4224..110ba073d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -218,19 +218,16 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() + defer it.mu.RUnlock() if !it.modified { - it.mu.RUnlock() return true } - it.mu.RUnlock() // Packets are manipulated only if connection and matching // NAT rule exists. shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - it.mu.RLock() - defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to @@ -418,3 +415,9 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } + +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + return it.connections.originalDst(epID) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a634b9b60..45f59b60f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -954,6 +954,10 @@ type DefaultTTLOption uint8 // classic BPF filter on a given endpoint. type SocketDetachFilterOption int +// OriginalDestinationOption is used to get the original destination address +// and port of a redirected packet. +type OriginalDestinationOption FullAddress + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0f7487963..682687ebe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2017,6 +2017,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() + case *tcpip.OriginalDestinationOption: + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID) + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } + default: return tcpip.ErrUnknownProtocolOption } diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 40b63ebbe..66453772a 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -9,6 +9,7 @@ go_library( "filter_input.go", "filter_output.go", "iptables.go", + "iptables_unsafe.go", "iptables_util.go", "nat.go", ], diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 550b6198a..fda5f694f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -371,3 +371,11 @@ func TestFilterAddrs(t *testing.T) { } } } + +func TestNATPreOriginalDst(t *testing.T) { + singleTest(t, NATPreOriginalDst{}) +} + +func TestNATOutOriginalDst(t *testing.T) { + singleTest(t, NATOutOriginalDst{}) +} diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go new file mode 100644 index 000000000..bd85a8fea --- /dev/null +++ b/test/iptables/iptables_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "syscall" + "unsafe" +) + +type originalDstError struct { + errno syscall.Errno +} + +func (e originalDstError) Error() string { + return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error()) +} + +// SO_ORIGINAL_DST gets the original destination of a redirected packet via +// getsockopt. +const SO_ORIGINAL_DST = 80 + +func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { + var addr syscall.RawSockaddrInet4 + var addrLen uint32 = syscall.SizeofSockaddrInet4 + if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet4{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { + var addr syscall.RawSockaddrInet6 + var addrLen uint32 = syscall.SizeofSockaddrInet6 + if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet6{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(connfd), + level, + SO_ORIGINAL_DST, + uintptr(optval), + uintptr(unsafe.Pointer(optlen)), + 0) + return errno +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index ca80a4b5f..5125fe47b 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,8 @@ package iptables import ( + "encoding/binary" + "errors" "fmt" "net" "os/exec" @@ -218,17 +220,58 @@ func filterAddrs(addrs []string, ipv6 bool) []string { // getInterfaceName returns the name of the interface other than loopback. func getInterfaceName() (string, bool) { - var ifname string + iface, ok := getNonLoopbackInterface() + if !ok { + return "", false + } + return iface.Name, true +} + +func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) { + iface, ok := getNonLoopbackInterface() + if !ok { + return nil, errors.New("no non-loopback interface found") + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + // Get only IPv4 or IPv6 addresses. + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + parts := strings.Split(addr.String(), "/") + var ip net.IP + // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses. + // So we check whether To4() returns nil to test whether the + // address is v4 or v6. + if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil { + ip = net.ParseIP(parts[0]).To16() + } else { + ip = v4 + } + if ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} + +func getNonLoopbackInterface() (net.Interface, bool) { if interfaces, err := net.Interfaces(); err == nil { for _, intf := range interfaces { if intf.Name != "lo" { - ifname = intf.Name - break + return intf, true } } } + return net.Interface{}, false +} - return ifname, ifname != "" +func htons(x uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, x) + return binary.LittleEndian.Uint16(buf) } func localIP(ipv6 bool) string { diff --git a/test/iptables/nat.go b/test/iptables/nat.go index ac0d91bb2..b7fea2527 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -18,12 +18,11 @@ import ( "errors" "fmt" "net" + "syscall" "time" ) -const ( - redirectPort = 42 -) +const redirectPort = 42 func init() { RegisterTestCase(NATPreRedirectUDPPort{}) @@ -42,6 +41,8 @@ func init() { RegisterTestCase(NATOutRedirectInvert{}) RegisterTestCase(NATRedirectRequiresProtocol{}) RegisterTestCase(NATLoopbackSkipsPrerouting{}) + RegisterTestCase(NATPreOriginalDst{}) + RegisterTestCase(NATOutOriginalDst{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -471,6 +472,151 @@ func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { return nil } +// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of PREROUTING NATted packets. +type NATPreOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATPreOriginalDst) Name() string { + return "NATPreOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "PREROUTING", + "-p", "tcp", + "--destination-port", fmt.Sprintf("%d", dropPort), + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + return listenForRedirectedConn(ipv6, addrs) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + return connectTCP(ip, dropPort, sendloopDuration) +} + +// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of OUTBOUND NATted packets. +type NATOutOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATOutOriginalDst) Name() string { + return "NATOutOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + connCh := make(chan error) + go func() { + connCh <- connectTCP(ip, dropPort, sendloopDuration) + }() + + if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + return err + } + return <-connCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { + // The net package doesn't give guarantee access to the connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for SO_ORIGINAL_DST. + + // Create the listening socket, bind, listen, and accept. + family := syscall.AF_INET + if ipv6 { + family = syscall.AF_INET6 + } + sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockfd) + + var bindAddr syscall.Sockaddr + if ipv6 { + bindAddr = &syscall.SockaddrInet6{ + Port: acceptPort, + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } else { + bindAddr = &syscall.SockaddrInet4{ + Port: acceptPort, + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + } + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return err + } + + if err := syscall.Listen(sockfd, 1); err != nil { + return err + } + + connfd, _, err := syscall.Accept(sockfd) + if err != nil { + return err + } + defer syscall.Close(connfd) + + // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST + // indicates the packet was sent to originalDst:dropPort. + if ipv6 { + got, err := originalDestination6(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } else { + got, err := originalDestination4(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } +} + // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { -- cgit v1.2.3 From b2ae7ea1bb207eddadd7962080e7bd0b8634db96 Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Mon, 3 Aug 2020 13:33:47 -0700 Subject: Plumbing context.Context to DecRef() and Release(). context is passed to DecRef() and Release() which is needed for SO_LINGER implementation. PiperOrigin-RevId: 324672584 --- pkg/refs/BUILD | 6 +- pkg/refs/refcounter.go | 19 +-- pkg/refs/refcounter_test.go | 38 ++--- pkg/sentry/control/proc.go | 2 +- pkg/sentry/devices/memdev/full.go | 2 +- pkg/sentry/devices/memdev/null.go | 2 +- pkg/sentry/devices/memdev/random.go | 2 +- pkg/sentry/devices/memdev/zero.go | 2 +- pkg/sentry/devices/ttydev/ttydev.go | 2 +- pkg/sentry/devices/tundev/tundev.go | 4 +- pkg/sentry/fdimport/fdimport.go | 8 +- pkg/sentry/fs/copy_up.go | 12 +- pkg/sentry/fs/copy_up_test.go | 4 +- pkg/sentry/fs/dev/net_tun.go | 4 +- pkg/sentry/fs/dirent.go | 110 +++++++------- pkg/sentry/fs/dirent_cache.go | 3 +- pkg/sentry/fs/dirent_refs_test.go | 16 +-- pkg/sentry/fs/dirent_state.go | 3 +- pkg/sentry/fs/fdpipe/pipe.go | 2 +- pkg/sentry/fs/fdpipe/pipe_opener_test.go | 16 +-- pkg/sentry/fs/fdpipe/pipe_test.go | 18 +-- pkg/sentry/fs/file.go | 10 +- pkg/sentry/fs/file_operations.go | 2 +- pkg/sentry/fs/file_overlay.go | 22 +-- pkg/sentry/fs/fsutil/file.go | 4 +- pkg/sentry/fs/gofer/file.go | 4 +- pkg/sentry/fs/gofer/gofer_test.go | 8 +- pkg/sentry/fs/gofer/handles.go | 5 +- pkg/sentry/fs/gofer/inode.go | 5 +- pkg/sentry/fs/gofer/path.go | 6 +- pkg/sentry/fs/gofer/session.go | 16 +-- pkg/sentry/fs/gofer/session_state.go | 3 +- pkg/sentry/fs/gofer/socket.go | 6 +- pkg/sentry/fs/host/control.go | 2 +- pkg/sentry/fs/host/file.go | 4 +- pkg/sentry/fs/host/inode_test.go | 2 +- pkg/sentry/fs/host/socket.go | 10 +- pkg/sentry/fs/host/socket_test.go | 38 ++--- pkg/sentry/fs/host/tty.go | 4 +- pkg/sentry/fs/host/wait_test.go | 2 +- pkg/sentry/fs/inode.go | 11 +- pkg/sentry/fs/inode_inotify.go | 5 +- pkg/sentry/fs/inode_overlay.go | 30 ++-- pkg/sentry/fs/inode_overlay_test.go | 8 +- pkg/sentry/fs/inotify.go | 8 +- pkg/sentry/fs/inotify_watch.go | 9 +- pkg/sentry/fs/mount.go | 12 +- pkg/sentry/fs/mount_overlay.go | 6 +- pkg/sentry/fs/mount_test.go | 29 ++-- pkg/sentry/fs/mounts.go | 30 ++-- pkg/sentry/fs/mounts_test.go | 2 +- pkg/sentry/fs/overlay.go | 10 +- pkg/sentry/fs/proc/fds.go | 18 +-- pkg/sentry/fs/proc/mounts.go | 8 +- pkg/sentry/fs/proc/net.go | 12 +- pkg/sentry/fs/proc/proc.go | 2 +- pkg/sentry/fs/proc/task.go | 4 +- pkg/sentry/fs/ramfs/dir.go | 18 +-- pkg/sentry/fs/ramfs/tree_test.go | 2 +- pkg/sentry/fs/timerfd/timerfd.go | 4 +- pkg/sentry/fs/tmpfs/file_test.go | 2 +- pkg/sentry/fs/tty/dir.go | 8 +- pkg/sentry/fs/tty/fs.go | 2 +- pkg/sentry/fs/tty/master.go | 8 +- pkg/sentry/fs/tty/slave.go | 4 +- pkg/sentry/fs/user/path.go | 8 +- pkg/sentry/fs/user/user.go | 8 +- pkg/sentry/fs/user/user_test.go | 8 +- pkg/sentry/fsbridge/bridge.go | 2 +- pkg/sentry/fsbridge/fs.go | 8 +- pkg/sentry/fsbridge/vfs.go | 6 +- pkg/sentry/fsimpl/devpts/devpts.go | 4 +- pkg/sentry/fsimpl/devpts/master.go | 6 +- pkg/sentry/fsimpl/devpts/slave.go | 6 +- pkg/sentry/fsimpl/devtmpfs/devtmpfs.go | 6 +- pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go | 8 +- pkg/sentry/fsimpl/eventfd/eventfd.go | 6 +- pkg/sentry/fsimpl/eventfd/eventfd_test.go | 12 +- pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 6 +- pkg/sentry/fsimpl/ext/dentry.go | 7 +- pkg/sentry/fsimpl/ext/directory.go | 2 +- pkg/sentry/fsimpl/ext/ext.go | 10 +- pkg/sentry/fsimpl/ext/ext_test.go | 4 +- pkg/sentry/fsimpl/ext/filesystem.go | 68 ++++----- pkg/sentry/fsimpl/ext/regular_file.go | 2 +- pkg/sentry/fsimpl/ext/symlink.go | 2 +- pkg/sentry/fsimpl/fuse/dev.go | 2 +- pkg/sentry/fsimpl/fuse/dev_test.go | 4 +- pkg/sentry/fsimpl/fuse/fusefs.go | 4 +- pkg/sentry/fsimpl/gofer/directory.go | 4 +- pkg/sentry/fsimpl/gofer/filesystem.go | 90 ++++++------ pkg/sentry/fsimpl/gofer/gofer.go | 40 +++--- pkg/sentry/fsimpl/gofer/gofer_test.go | 6 +- pkg/sentry/fsimpl/gofer/regular_file.go | 2 +- pkg/sentry/fsimpl/gofer/socket.go | 6 +- pkg/sentry/fsimpl/gofer/special_file.go | 4 +- pkg/sentry/fsimpl/host/control.go | 2 +- pkg/sentry/fsimpl/host/host.go | 14 +- pkg/sentry/fsimpl/host/socket.go | 12 +- pkg/sentry/fsimpl/host/tty.go | 4 +- pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 2 +- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/filesystem.go | 68 ++++----- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 10 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 26 ++-- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 18 +-- pkg/sentry/fsimpl/overlay/copy_up.go | 8 +- pkg/sentry/fsimpl/overlay/directory.go | 8 +- pkg/sentry/fsimpl/overlay/filesystem.go | 64 ++++----- pkg/sentry/fsimpl/overlay/non_directory.go | 18 +-- pkg/sentry/fsimpl/overlay/overlay.go | 48 +++---- pkg/sentry/fsimpl/pipefs/pipefs.go | 6 +- pkg/sentry/fsimpl/proc/filesystem.go | 4 +- pkg/sentry/fsimpl/proc/task_fds.go | 18 +-- pkg/sentry/fsimpl/proc/task_files.go | 14 +- pkg/sentry/fsimpl/proc/task_net.go | 12 +- pkg/sentry/fsimpl/proc/tasks_test.go | 10 +- pkg/sentry/fsimpl/signalfd/signalfd.go | 4 +- pkg/sentry/fsimpl/sockfs/sockfs.go | 4 +- pkg/sentry/fsimpl/sys/sys.go | 4 +- pkg/sentry/fsimpl/sys/sys_test.go | 2 +- pkg/sentry/fsimpl/testutil/kernel.go | 2 +- pkg/sentry/fsimpl/testutil/testutil.go | 6 +- pkg/sentry/fsimpl/timerfd/timerfd.go | 6 +- pkg/sentry/fsimpl/tmpfs/benchmark_test.go | 50 +++---- pkg/sentry/fsimpl/tmpfs/directory.go | 4 +- pkg/sentry/fsimpl/tmpfs/filesystem.go | 106 +++++++------- pkg/sentry/fsimpl/tmpfs/pipe_test.go | 20 +-- pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 34 ++--- pkg/sentry/fsimpl/tmpfs/tmpfs_test.go | 6 +- pkg/sentry/kernel/abstract_socket_namespace.go | 13 +- pkg/sentry/kernel/epoll/epoll.go | 14 +- pkg/sentry/kernel/epoll/epoll_test.go | 5 +- pkg/sentry/kernel/eventfd/eventfd.go | 4 +- pkg/sentry/kernel/fd_table.go | 55 +++---- pkg/sentry/kernel/fd_table_test.go | 6 +- pkg/sentry/kernel/fs_context.go | 31 ++-- pkg/sentry/kernel/futex/BUILD | 1 + pkg/sentry/kernel/futex/futex.go | 35 ++--- pkg/sentry/kernel/futex/futex_test.go | 66 +++++---- pkg/sentry/kernel/kernel.go | 57 ++++---- pkg/sentry/kernel/pipe/node.go | 6 +- pkg/sentry/kernel/pipe/node_test.go | 2 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_test.go | 16 +-- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/pipe/reader.go | 3 +- pkg/sentry/kernel/pipe/vfs.go | 8 +- pkg/sentry/kernel/pipe/writer.go | 3 +- pkg/sentry/kernel/sessions.go | 5 +- pkg/sentry/kernel/shm/shm.go | 10 +- pkg/sentry/kernel/signalfd/signalfd.go | 2 +- pkg/sentry/kernel/task.go | 8 +- pkg/sentry/kernel/task_clone.go | 10 +- pkg/sentry/kernel/task_exec.go | 6 +- pkg/sentry/kernel/task_exit.go | 8 +- pkg/sentry/kernel/task_log.go | 2 +- pkg/sentry/kernel/task_start.go | 6 +- pkg/sentry/kernel/thread_group.go | 4 +- pkg/sentry/loader/elf.go | 4 +- pkg/sentry/loader/loader.go | 6 +- pkg/sentry/memmap/memmap.go | 2 +- pkg/sentry/mm/aio_context.go | 6 +- pkg/sentry/mm/lifecycle.go | 2 +- pkg/sentry/mm/metadata.go | 5 +- pkg/sentry/mm/special_mappable.go | 4 +- pkg/sentry/mm/syscalls.go | 4 +- pkg/sentry/mm/vma.go | 4 +- pkg/sentry/socket/control/control.go | 8 +- pkg/sentry/socket/control/control_vfs2.go | 8 +- pkg/sentry/socket/hostinet/socket.go | 8 +- pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 14 +- pkg/sentry/socket/netlink/socket_vfs2.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 6 +- pkg/sentry/socket/netstack/netstack_vfs2.go | 2 +- pkg/sentry/socket/socket.go | 4 +- pkg/sentry/socket/unix/transport/connectioned.go | 10 +- pkg/sentry/socket/unix/transport/connectionless.go | 12 +- pkg/sentry/socket/unix/transport/queue.go | 13 +- pkg/sentry/socket/unix/transport/unix.go | 48 +++---- pkg/sentry/socket/unix/unix.go | 38 ++--- pkg/sentry/socket/unix/unix_vfs2.go | 18 +-- pkg/sentry/strace/strace.go | 12 +- pkg/sentry/syscalls/epoll.go | 18 +-- pkg/sentry/syscalls/linux/sys_aio.go | 8 +- pkg/sentry/syscalls/linux/sys_eventfd.go | 2 +- pkg/sentry/syscalls/linux/sys_file.go | 78 +++++----- pkg/sentry/syscalls/linux/sys_futex.go | 6 +- pkg/sentry/syscalls/linux/sys_getdents.go | 2 +- pkg/sentry/syscalls/linux/sys_inotify.go | 10 +- pkg/sentry/syscalls/linux/sys_lseek.go | 2 +- pkg/sentry/syscalls/linux/sys_mmap.go | 4 +- pkg/sentry/syscalls/linux/sys_mount.go | 2 +- pkg/sentry/syscalls/linux/sys_pipe.go | 6 +- pkg/sentry/syscalls/linux/sys_poll.go | 10 +- pkg/sentry/syscalls/linux/sys_prctl.go | 4 +- pkg/sentry/syscalls/linux/sys_read.go | 12 +- pkg/sentry/syscalls/linux/sys_shm.go | 10 +- pkg/sentry/syscalls/linux/sys_signal.go | 4 +- pkg/sentry/syscalls/linux/sys_socket.go | 46 +++--- pkg/sentry/syscalls/linux/sys_splice.go | 12 +- pkg/sentry/syscalls/linux/sys_stat.go | 8 +- pkg/sentry/syscalls/linux/sys_sync.go | 8 +- pkg/sentry/syscalls/linux/sys_thread.go | 6 +- pkg/sentry/syscalls/linux/sys_timerfd.go | 6 +- pkg/sentry/syscalls/linux/sys_write.go | 10 +- pkg/sentry/syscalls/linux/sys_xattr.go | 8 +- pkg/sentry/syscalls/linux/vfs2/aio.go | 8 +- pkg/sentry/syscalls/linux/vfs2/epoll.go | 14 +- pkg/sentry/syscalls/linux/vfs2/eventfd.go | 4 +- pkg/sentry/syscalls/linux/vfs2/execve.go | 12 +- pkg/sentry/syscalls/linux/vfs2/fd.go | 12 +- pkg/sentry/syscalls/linux/vfs2/filesystem.go | 22 +-- pkg/sentry/syscalls/linux/vfs2/fscontext.go | 22 +-- pkg/sentry/syscalls/linux/vfs2/getdents.go | 2 +- pkg/sentry/syscalls/linux/vfs2/inotify.go | 14 +- pkg/sentry/syscalls/linux/vfs2/ioctl.go | 2 +- pkg/sentry/syscalls/linux/vfs2/lock.go | 2 +- pkg/sentry/syscalls/linux/vfs2/memfd.go | 2 +- pkg/sentry/syscalls/linux/vfs2/mmap.go | 4 +- pkg/sentry/syscalls/linux/vfs2/mount.go | 4 +- pkg/sentry/syscalls/linux/vfs2/path.go | 12 +- pkg/sentry/syscalls/linux/vfs2/pipe.go | 6 +- pkg/sentry/syscalls/linux/vfs2/poll.go | 10 +- pkg/sentry/syscalls/linux/vfs2/read_write.go | 48 +++---- pkg/sentry/syscalls/linux/vfs2/setstat.go | 20 +-- pkg/sentry/syscalls/linux/vfs2/signal.go | 4 +- pkg/sentry/syscalls/linux/vfs2/socket.go | 46 +++--- pkg/sentry/syscalls/linux/vfs2/splice.go | 20 +-- pkg/sentry/syscalls/linux/vfs2/stat.go | 30 ++-- pkg/sentry/syscalls/linux/vfs2/sync.go | 6 +- pkg/sentry/syscalls/linux/vfs2/timerfd.go | 8 +- pkg/sentry/syscalls/linux/vfs2/xattr.go | 16 +-- pkg/sentry/vfs/anonfs.go | 8 +- pkg/sentry/vfs/dentry.go | 37 ++--- pkg/sentry/vfs/epoll.go | 6 +- pkg/sentry/vfs/file_description.go | 24 ++-- pkg/sentry/vfs/file_description_impl_util_test.go | 18 +-- pkg/sentry/vfs/filesystem.go | 6 +- pkg/sentry/vfs/inotify.go | 34 ++--- pkg/sentry/vfs/mount.go | 54 +++---- pkg/sentry/vfs/pathname.go | 18 +-- pkg/sentry/vfs/resolving_path.go | 43 +++--- pkg/sentry/vfs/vfs.go | 160 ++++++++++----------- pkg/tcpip/link/tun/BUILD | 1 + pkg/tcpip/link/tun/device.go | 9 +- runsc/boot/fs.go | 12 +- runsc/boot/loader.go | 16 +-- runsc/boot/loader_test.go | 8 +- runsc/boot/vfs.go | 10 +- 252 files changed, 1711 insertions(+), 1668 deletions(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 74affc887..9888cce9c 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -24,6 +24,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/context", "//pkg/log", "//pkg/sync", ], @@ -34,5 +35,8 @@ go_test( size = "small", srcs = ["refcounter_test.go"], library = ":refs", - deps = ["//pkg/sync"], + deps = [ + "//pkg/context", + "//pkg/sync", + ], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index c45ba8200..61790221b 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -23,6 +23,7 @@ import ( "runtime" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" ) @@ -38,7 +39,7 @@ type RefCounter interface { // Note that AtomicRefCounter.DecRef() does not support destructors. // If a type has a destructor, it must implement its own DecRef() // method and call AtomicRefCounter.DecRefWithDestructor(destructor). - DecRef() + DecRef(ctx context.Context) // TryIncRef attempts to increase the reference counter on the object, // but may fail if all references have already been dropped. This @@ -57,7 +58,7 @@ type RefCounter interface { // A WeakRefUser is notified when the last non-weak reference is dropped. type WeakRefUser interface { // WeakRefGone is called when the last non-weak reference is dropped. - WeakRefGone() + WeakRefGone(ctx context.Context) } // WeakRef is a weak reference. @@ -123,7 +124,7 @@ func (w *WeakRef) Get() RefCounter { // Drop drops this weak reference. You should always call drop when you are // finished with the weak reference. You may not use this object after calling // drop. -func (w *WeakRef) Drop() { +func (w *WeakRef) Drop(ctx context.Context) { rc, ok := w.get() if !ok { // We've been zapped already. When the refcounter has called @@ -145,7 +146,7 @@ func (w *WeakRef) Drop() { // And now aren't on the object's list of weak references. So it won't // zap us if this causes the reference count to drop to zero. - rc.DecRef() + rc.DecRef(ctx) // Return to the pool. weakRefPool.Put(w) @@ -427,7 +428,7 @@ func (r *AtomicRefCount) dropWeakRef(w *WeakRef) { // A: TryIncRef [transform speculative to real] // //go:nosplit -func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { +func (r *AtomicRefCount) DecRefWithDestructor(ctx context.Context, destroy func(context.Context)) { switch v := atomic.AddInt64(&r.refCount, -1); { case v < -1: panic("Decrementing non-positive ref count") @@ -448,7 +449,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { if user != nil { r.mu.Unlock() - user.WeakRefGone() + user.WeakRefGone(ctx) r.mu.Lock() } } @@ -456,7 +457,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { // Call the destructor. if destroy != nil { - destroy() + destroy(ctx) } } } @@ -464,6 +465,6 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) { // DecRef decrements this object's reference count. // //go:nosplit -func (r *AtomicRefCount) DecRef() { - r.DecRefWithDestructor(nil) +func (r *AtomicRefCount) DecRef(ctx context.Context) { + r.DecRefWithDestructor(ctx, nil) } diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index 1ab4a4440..6d0dd1018 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -18,6 +18,7 @@ import ( "reflect" "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -31,11 +32,11 @@ type testCounter struct { destroyed bool } -func (t *testCounter) DecRef() { - t.AtomicRefCount.DecRefWithDestructor(t.destroy) +func (t *testCounter) DecRef(ctx context.Context) { + t.AtomicRefCount.DecRefWithDestructor(ctx, t.destroy) } -func (t *testCounter) destroy() { +func (t *testCounter) destroy(context.Context) { t.mu.Lock() defer t.mu.Unlock() t.destroyed = true @@ -53,7 +54,7 @@ func newTestCounter() *testCounter { func TestOneRef(t *testing.T) { tc := newTestCounter() - tc.DecRef() + tc.DecRef(context.Background()) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -63,8 +64,9 @@ func TestOneRef(t *testing.T) { func TestTwoRefs(t *testing.T) { tc := newTestCounter() tc.IncRef() - tc.DecRef() - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) + tc.DecRef(ctx) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -74,12 +76,13 @@ func TestTwoRefs(t *testing.T) { func TestMultiRefs(t *testing.T) { tc := newTestCounter() tc.IncRef() - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) tc.IncRef() - tc.DecRef() + tc.DecRef(ctx) - tc.DecRef() + tc.DecRef(ctx) if !tc.IsDestroyed() { t.Errorf("object should have been destroyed") @@ -89,19 +92,20 @@ func TestMultiRefs(t *testing.T) { func TestWeakRef(t *testing.T) { tc := newTestCounter() w := NewWeakRef(tc, nil) + ctx := context.Background() // Try resolving. if x := w.Get(); x == nil { t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) } else { - x.DecRef() + x.DecRef(ctx) } // Try resolving again. if x := w.Get(); x == nil { t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) } else { - x.DecRef() + x.DecRef(ctx) } // Shouldn't be destroyed yet. (Can't continue if this fails.) @@ -110,7 +114,7 @@ func TestWeakRef(t *testing.T) { } // Drop the original reference. - tc.DecRef() + tc.DecRef(ctx) // Assert destroyed. if !tc.IsDestroyed() { @@ -126,7 +130,8 @@ func TestWeakRef(t *testing.T) { func TestWeakRefDrop(t *testing.T) { tc := newTestCounter() w := NewWeakRef(tc, nil) - w.Drop() + ctx := context.Background() + w.Drop(ctx) // Just assert the list is empty. if !tc.weakRefs.Empty() { @@ -134,14 +139,14 @@ func TestWeakRefDrop(t *testing.T) { } // Drop the original reference. - tc.DecRef() + tc.DecRef(ctx) } type testWeakRefUser struct { weakRefGone func() } -func (u *testWeakRefUser) WeakRefGone() { +func (u *testWeakRefUser) WeakRefGone(ctx context.Context) { u.weakRefGone() } @@ -165,7 +170,8 @@ func TestCallback(t *testing.T) { }}) // Drop the original reference, this must trigger the callback. - tc.DecRef() + ctx := context.Background() + tc.DecRef(ctx) if !called { t.Fatalf("Callback not called") diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 1bae7cfaf..dfa936563 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -139,7 +139,6 @@ func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { // Import file descriptors. fdTable := proc.Kernel.NewFDTable() - defer fdTable.DecRef() creds := auth.NewUserCredentials( args.KUID, @@ -177,6 +176,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI initArgs.MountNamespaceVFS2.IncRef() } ctx := initArgs.NewContext(proc.Kernel) + defer fdTable.DecRef(ctx) if kernel.VFS2Enabled { // Get the full path to the filename from the PATH env variable. diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go index af66fe4dc..511179e31 100644 --- a/pkg/sentry/devices/memdev/full.go +++ b/pkg/sentry/devices/memdev/full.go @@ -46,7 +46,7 @@ type fullFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *fullFD) Release() { +func (fd *fullFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go index 92d3d71be..4918dbeeb 100644 --- a/pkg/sentry/devices/memdev/null.go +++ b/pkg/sentry/devices/memdev/null.go @@ -47,7 +47,7 @@ type nullFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nullFD) Release() { +func (fd *nullFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go index 6b81da5ef..5e7fe0280 100644 --- a/pkg/sentry/devices/memdev/random.go +++ b/pkg/sentry/devices/memdev/random.go @@ -56,7 +56,7 @@ type randomFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *randomFD) Release() { +func (fd *randomFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index c6f15054d..2e631a252 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -48,7 +48,7 @@ type zeroFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *zeroFD) Release() { +func (fd *zeroFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go index fbb7fd92c..fd4b79c46 100644 --- a/pkg/sentry/devices/ttydev/ttydev.go +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -55,7 +55,7 @@ type ttyFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *ttyFD) Release() {} +func (fd *ttyFD) Release(context.Context) {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index dfbd069af..852ec3c5c 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -108,8 +108,8 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *tunFD) Release() { - fd.device.Release() +func (fd *tunFD) Release(ctx context.Context) { + fd.device.Release(ctx) } // PRead implements vfs.FileDescriptionImpl.PRead. diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index b8686adb4..1b7cb94c0 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -50,7 +50,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) // Remember this in the TTY file, as we will // use it for the other stdio FDs. @@ -69,7 +69,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) } // Add the file to the FD map. @@ -102,7 +102,7 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) // Remember this in the TTY file, as we will use it for the other stdio // FDs. @@ -119,7 +119,7 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi if err != nil { return nil, err } - defer appFile.DecRef() + defer appFile.DecRef(ctx) } if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil { diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ab1424c95..735452b07 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -201,7 +201,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { parentUpper := parent.Inode.overlay.upper root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } // Create the file in the upper filesystem and get an Inode for it. @@ -212,7 +212,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { log.Warningf("copy up failed to create file: %v", err) return syserror.EIO } - defer childFile.DecRef() + defer childFile.DecRef(ctx) childUpperInode = childFile.Dirent.Inode case Directory: @@ -226,7 +226,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { cleanupUpper(ctx, parentUpper, next.name, werr) return syserror.EIO } - defer childUpper.DecRef() + defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode case Symlink: @@ -246,7 +246,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { cleanupUpper(ctx, parentUpper, next.name, werr) return syserror.EIO } - defer childUpper.DecRef() + defer childUpper.DecRef(ctx) childUpperInode = childUpper.Inode default: @@ -352,14 +352,14 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in if err != nil { return err } - defer upperFile.DecRef() + defer upperFile.DecRef(ctx) // Get a handle to the lower filesystem, which we will read from. lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true}) if err != nil { return err } - defer lowerFile.DecRef() + defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. buf := copyUpBuffers.Get().([]byte) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 91792d9fe..c7a11eec1 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -126,7 +126,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { if err != nil { t.Fatalf("failed to create file %q: %v", name, err) } - defer f.DecRef() + defer f.DecRef(ctx) relname, _ := f.Dirent.FullName(lowerRoot) @@ -171,7 +171,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { if err != nil { t.Fatalf("failed to find %q: %v", f.name, err) } - defer d.DecRef() + defer d.DecRef(ctx) f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) if err != nil { diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index dc7ad075a..ec474e554 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -80,8 +80,8 @@ type netTunFileOperations struct { var _ fs.FileOperations = (*netTunFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (fops *netTunFileOperations) Release() { - fops.device.Release() +func (fops *netTunFileOperations) Release(ctx context.Context) { + fops.device.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 65be12175..a2f751068 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -325,7 +325,7 @@ func (d *Dirent) SyncAll(ctx context.Context) { for _, w := range d.children { if child := w.Get(); child != nil { child.(*Dirent).SyncAll(ctx) - child.DecRef() + child.DecRef(ctx) } } } @@ -451,7 +451,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // which don't hold a hard reference on their parent (their parent holds a // hard reference on them, and they contain virtually no state). But this is // good house-keeping. - child.DecRef() + child.DecRef(ctx) return nil, syscall.ENOENT } @@ -468,20 +468,20 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // their pins on the child. Inotify doesn't properly support filesystems that // revalidate dirents (since watches are lost on revalidation), but if we fail // to unpin the watches child will never be GCed. - cd.Inode.Watches.Unpin(cd) + cd.Inode.Watches.Unpin(ctx, cd) // This child needs to be revalidated, fallthrough to unhash it. Make sure // to not leak a reference from Get(). // // Note that previous lookups may still have a reference to this stale child; // this can't be helped, but we can ensure that *new* lookups are up-to-date. - child.DecRef() + child.DecRef(ctx) } // Either our weak reference expired or we need to revalidate it. Unhash child first, we're // about to replace it. delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be @@ -512,12 +512,12 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // There are active references to the existing child, prefer it to the one we // retrieved from Lookup. Likely the Lookup happened very close to the insertion // of child, so considering one stale over the other is fairly arbitrary. - c.DecRef() + c.DecRef(ctx) // The child that was installed could be negative. if cd.IsNegative() { // If so, don't leak a reference and short circuit. - child.DecRef() + child.DecRef(ctx) return nil, syscall.ENOENT } @@ -531,7 +531,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl // we did the Inode.Lookup. Fully drop the weak reference and fallback to using the child // we looked up. delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Give the looked up child a parent. We cannot kick out entries, since we just checked above @@ -587,7 +587,7 @@ func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool { return false } // Child exists. - child.DecRef() + child.DecRef(ctx) return true } @@ -622,7 +622,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi } child := file.Dirent - d.finishCreate(child, name) + d.finishCreate(ctx, child, name) // Return the reference and the new file. When the last reference to // the file is dropped, file.Dirent may no longer be cached. @@ -631,7 +631,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi // finishCreate validates the created file, adds it as a child of this dirent, // and notifies any watchers. -func (d *Dirent) finishCreate(child *Dirent, name string) { +func (d *Dirent) finishCreate(ctx context.Context, child *Dirent, name string) { // Sanity check c, its name must be consistent. if child.name != name { panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name)) @@ -650,14 +650,14 @@ func (d *Dirent) finishCreate(child *Dirent, name string) { panic(fmt.Sprintf("hashed child %q over a positive child", child.name)) } // Don't leak a reference. - old.DecRef() + old.DecRef(ctx) // Drop d's reference. - old.DecRef() + old.DecRef(ctx) } // Finally drop the useless weak reference on the floor. - w.Drop() + w.Drop(ctx) } d.Inode.Watches.Notify(name, linux.IN_CREATE, 0) @@ -686,17 +686,17 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name)) } // Don't leak a reference. - old.DecRef() + old.DecRef(ctx) // Drop d's reference. - old.DecRef() + old.DecRef(ctx) } // Unhash the negative Dirent, name needs to exist now. delete(d.children, name) // Finally drop the useless weak reference on the floor. - w.Drop() + w.Drop(ctx) } // Execute the create operation. @@ -756,7 +756,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans if e != nil { return e } - d.finishCreate(childDir, name) + d.finishCreate(ctx, childDir, name) return nil }) if err == syscall.EEXIST { @@ -901,7 +901,7 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent, // references to children. // // Preconditions: d.mu must be held. -func (d *Dirent) flush() { +func (d *Dirent) flush(ctx context.Context) { expired := make(map[string]*refs.WeakRef) for n, w := range d.children { // Call flush recursively on each child before removing our @@ -912,7 +912,7 @@ func (d *Dirent) flush() { if !cd.IsNegative() { // Flush the child. cd.mu.Lock() - cd.flush() + cd.flush(ctx) cd.mu.Unlock() // Allow the file system to drop extra references on child. @@ -920,13 +920,13 @@ func (d *Dirent) flush() { } // Don't leak a reference. - child.DecRef() + child.DecRef(ctx) } // Check if the child dirent is closed, and mark it as expired if it is. // We must call w.Get() again here, since the child could have been closed // by the calls to flush() and cache.Remove() in the above if-block. if child := w.Get(); child != nil { - child.DecRef() + child.DecRef(ctx) } else { expired[n] = w } @@ -935,7 +935,7 @@ func (d *Dirent) flush() { // Remove expired entries. for n, w := range expired { delete(d.children, n) - w.Drop() + w.Drop(ctx) } } @@ -977,7 +977,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err if !ok { panic("mount must mount over an existing dirent") } - weakRef.Drop() + weakRef.Drop(ctx) // Note that even though `d` is now hidden, it still holds a reference // to its parent. @@ -1002,13 +1002,13 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { if !ok { panic("mount must mount over an existing dirent") } - weakRef.Drop() + weakRef.Drop(ctx) // d is not reachable anymore, and hence not mounted anymore. d.mounted = false // Drop mount reference. - d.DecRef() + d.DecRef(ctx) return nil } @@ -1029,7 +1029,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath // Child does not exist. return err } - defer child.DecRef() + defer child.DecRef(ctx) // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { @@ -1055,7 +1055,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath atomic.StoreInt32(&child.deleted, 1) if w, ok := d.children[name]; ok { delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Allow the file system to drop extra references on child. @@ -1067,7 +1067,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath // inode may have other links. If this was the last link, the events for the // watch removal will be queued by the inode destructor. child.Inode.Watches.MarkUnlinked() - child.Inode.Watches.Unpin(child) + child.Inode.Watches.Unpin(ctx, child) d.Inode.Watches.Notify(name, linux.IN_DELETE, 0) return nil @@ -1100,7 +1100,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) // Child does not exist. return err } - defer child.DecRef() + defer child.DecRef(ctx) // RemoveDirectory can only remove directories. if !IsDir(child.Inode.StableAttr) { @@ -1121,7 +1121,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) atomic.StoreInt32(&child.deleted, 1) if w, ok := d.children[name]; ok { delete(d.children, name) - w.Drop() + w.Drop(ctx) } // Allow the file system to drop extra references on child. @@ -1130,14 +1130,14 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string) // Finally, let inotify know the child is being unlinked. Drop any extra // refs from inotify to this child dirent. child.Inode.Watches.MarkUnlinked() - child.Inode.Watches.Unpin(child) + child.Inode.Watches.Unpin(ctx, child) d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0) return nil } // destroy closes this node and all children. -func (d *Dirent) destroy() { +func (d *Dirent) destroy(ctx context.Context) { if d.IsNegative() { // Nothing to tear-down and no parent references to drop, since a negative // Dirent does not take a references on its parent, has no Inode and no children. @@ -1153,19 +1153,19 @@ func (d *Dirent) destroy() { if c.(*Dirent).IsNegative() { // The parent holds both weak and strong refs in the case of // negative dirents. - c.DecRef() + c.DecRef(ctx) } // Drop the reference we just acquired in WeakRef.Get. - c.DecRef() + c.DecRef(ctx) } - w.Drop() + w.Drop(ctx) } d.children = nil allDirents.remove(d) // Drop our reference to the Inode. - d.Inode.DecRef() + d.Inode.DecRef(ctx) // Allow the Dirent to be GC'ed after this point, since the Inode may still // be referenced after the Dirent is destroyed (for instance by filesystem @@ -1175,7 +1175,7 @@ func (d *Dirent) destroy() { // Drop the reference we have on our parent if we took one. renameMu doesn't need to be // held because d can't be reparented without any references to it left. if d.parent != nil { - d.parent.DecRef() + d.parent.DecRef(ctx) } } @@ -1201,14 +1201,14 @@ func (d *Dirent) TryIncRef() bool { // DecRef decreases the Dirent's refcount and drops its reference on its mount. // // DecRef implements RefCounter.DecRef with destructor d.destroy. -func (d *Dirent) DecRef() { +func (d *Dirent) DecRef(ctx context.Context) { if d.Inode != nil { // Keep mount around, since DecRef may destroy d.Inode. msrc := d.Inode.MountSource - d.DecRefWithDestructor(d.destroy) + d.DecRefWithDestructor(ctx, d.destroy) msrc.DecDirentRefs() } else { - d.DecRefWithDestructor(d.destroy) + d.DecRefWithDestructor(ctx, d.destroy) } } @@ -1359,7 +1359,7 @@ func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error if err != nil { return err } - defer victim.DecRef() + defer victim.DecRef(ctx) return d.mayDelete(ctx, victim) } @@ -1411,7 +1411,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string if err != nil { return err } - defer renamed.DecRef() + defer renamed.DecRef(ctx) // Check that the renamed dirent is deletable. if err := oldParent.mayDelete(ctx, renamed); err != nil { @@ -1453,13 +1453,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Check that we can delete replaced. if err := newParent.mayDelete(ctx, replaced); err != nil { - replaced.DecRef() + replaced.DecRef(ctx) return err } // Target should not be an ancestor of source. if oldParent.descendantOf(replaced) { - replaced.DecRef() + replaced.DecRef(ctx) // Note that Linux returns EINVAL if the source is an // ancestor of target, but ENOTEMPTY if the target is @@ -1470,7 +1470,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Check that replaced is not a mount point. if replaced.isMountPointLocked() { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.EBUSY } @@ -1478,11 +1478,11 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string oldIsDir := IsDir(renamed.Inode.StableAttr) newIsDir := IsDir(replaced.Inode.StableAttr) if !newIsDir && oldIsDir { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.ENOTDIR } if !oldIsDir && newIsDir { - replaced.DecRef() + replaced.DecRef(ctx) return syscall.EISDIR } @@ -1493,13 +1493,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // open across renames is currently broken for multiple // reasons, so we flush all references on the replaced node and // its children. - replaced.Inode.Watches.Unpin(replaced) + replaced.Inode.Watches.Unpin(ctx, replaced) replaced.mu.Lock() - replaced.flush() + replaced.flush(ctx) replaced.mu.Unlock() // Done with replaced. - replaced.DecRef() + replaced.DecRef(ctx) } if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil { @@ -1513,14 +1513,14 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // can't destroy oldParent (and try to retake its lock) because // Rename's caller must be holding a reference. newParent.IncRef() - oldParent.DecRef() + oldParent.DecRef(ctx) } if w, ok := newParent.children[newName]; ok { - w.Drop() + w.Drop(ctx) delete(newParent.children, newName) } if w, ok := oldParent.children[oldName]; ok { - w.Drop() + w.Drop(ctx) delete(oldParent.children, oldName) } @@ -1551,7 +1551,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // Same as replaced.flush above. renamed.mu.Lock() - renamed.flush() + renamed.flush(ctx) renamed.mu.Unlock() return nil diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 33de32c69..7d9dd717e 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -17,6 +17,7 @@ package fs import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -101,7 +102,7 @@ func (c *DirentCache) remove(d *Dirent) { panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d)) } c.list.Remove(d) - d.DecRef() + d.DecRef(context.Background()) c.currentSize-- if c.limit != nil { c.limit.dec() diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 98d69c6f2..176b894ba 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -51,7 +51,7 @@ func TestWalkPositive(t *testing.T) { t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1) } - d.DecRef() + d.DecRef(ctx) if got := root.ReadRefs(); got != 1 { t.Fatalf("root has a ref count of %d, want %d", got, 1) @@ -61,7 +61,7 @@ func TestWalkPositive(t *testing.T) { t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0) } - root.flush() + root.flush(ctx) if got := len(root.children); got != 0 { t.Fatalf("root has %d children, want %d", got, 0) @@ -114,7 +114,7 @@ func TestWalkNegative(t *testing.T) { t.Fatalf("child has a ref count of %d, want %d", got, 2) } - child.DecRef() + child.DecRef(ctx) if got := child.(*Dirent).ReadRefs(); got != 1 { t.Fatalf("child has a ref count of %d, want %d", got, 1) @@ -124,7 +124,7 @@ func TestWalkNegative(t *testing.T) { t.Fatalf("root has %d children, want %d", got, 1) } - root.DecRef() + root.DecRef(ctx) if got := root.ReadRefs(); got != 0 { t.Fatalf("root has a ref count of %d, want %d", got, 0) @@ -351,9 +351,9 @@ func TestRemoveExtraRefs(t *testing.T) { t.Fatalf("dirent has a ref count of %d, want %d", got, 1) } - d.DecRef() + d.DecRef(ctx) - test.root.flush() + test.root.flush(ctx) if got := len(test.root.children); got != 0 { t.Errorf("root has %d children, want %d", got, 0) @@ -403,8 +403,8 @@ func TestRenameExtraRefs(t *testing.T) { t.Fatalf("Rename got error %v, want nil", err) } - oldParent.flush() - newParent.flush() + oldParent.flush(ctx) + newParent.flush(ctx) // Expect to have only active references. if got := renamed.ReadRefs(); got != 1 { diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go index f623d6c0e..67a35f0b2 100644 --- a/pkg/sentry/fs/dirent_state.go +++ b/pkg/sentry/fs/dirent_state.go @@ -18,6 +18,7 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" ) @@ -48,7 +49,7 @@ func (d *Dirent) saveChildren() map[string]*Dirent { for name, w := range d.children { if rc := w.Get(); rc != nil { // Drop the reference count obtain in w.Get() - rc.DecRef() + rc.DecRef(context.Background()) cd := rc.(*Dirent) if cd.IsNegative() { diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 9fce177ad..b99199798 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -115,7 +115,7 @@ func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.Even } // Release implements fs.FileOperations.Release. -func (p *pipeOperations) Release() { +func (p *pipeOperations) Release(context.Context) { fdnotifier.RemoveFD(int32(p.file.FD())) p.file.Close() p.file = nil diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index e556da48a..b9cec4b13 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -182,7 +182,7 @@ func TestTryOpen(t *testing.T) { // Cleanup the state of the pipe, and remove the fd from the // fdnotifier. Sadly this needed to maintain the correctness // of other tests because the fdnotifier is global. - pipeOps.Release() + pipeOps.Release(ctx) } continue } @@ -191,7 +191,7 @@ func TestTryOpen(t *testing.T) { } if pipeOps != nil { // Same as above. - pipeOps.Release() + pipeOps.Release(ctx) } } } @@ -279,7 +279,7 @@ func TestPipeOpenUnblocksEventually(t *testing.T) { pipeOps, err := Open(ctx, opener, flags) if pipeOps != nil { // Same as TestTryOpen. - pipeOps.Release() + pipeOps.Release(ctx) } // Check that the partner opened the file successfully. @@ -325,7 +325,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) { ctx := contexttest.Context(t) pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) if pipeOps != nil { - pipeOps.Release() + pipeOps.Release(ctx) t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY) } if err != syserror.ErrWouldBlock { @@ -351,7 +351,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) { if pipeOps == nil { t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY) } - defer pipeOps.Release() + defer pipeOps.Release(ctx) if err != nil { t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err) @@ -471,14 +471,14 @@ func TestPipeHangup(t *testing.T) { f := <-fdchan if f < 0 { t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f) - pipeOps.Release() + pipeOps.Release(ctx) continue } if test.hangupSelf { // Hangup self and assert that our partner got the expected hangup // error. - pipeOps.Release() + pipeOps.Release(ctx) if test.flags.Read { // Partner is writer. @@ -490,7 +490,7 @@ func TestPipeHangup(t *testing.T) { } else { // Hangup our partner and expect us to get the hangup error. syscall.Close(f) - defer pipeOps.Release() + defer pipeOps.Release(ctx) if test.flags.Read { assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file) diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index a0082ecca..1c9e82562 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -98,10 +98,11 @@ func TestNewPipe(t *testing.T) { } f := fd.New(gfd) - p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer) + ctx := contexttest.Context(t) + p, err := newPipeOperations(ctx, nil, test.flags, f, test.readAheadBuffer) if p != nil { // This is necessary to remove the fd from the global fd notifier. - defer p.Release() + defer p.Release(ctx) } else { // If there is no p to DecRef on, because newPipeOperations failed, then the // file still needs to be closed. @@ -153,13 +154,14 @@ func TestPipeDestruction(t *testing.T) { syscall.Close(fds[1]) // Test the read end, but it doesn't really matter which. - p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil) + ctx := contexttest.Context(t) + p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, f, nil) if err != nil { f.Close() t.Fatalf("newPipeOperations got error %v, want nil", err) } // Drop our only reference, which should trigger the destructor. - p.Release() + p.Release(ctx) if fdnotifier.HasFD(int32(fds[0])) { t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0]) @@ -282,7 +284,7 @@ func TestPipeRequest(t *testing.T) { if err != nil { t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err) } - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p) @@ -334,7 +336,7 @@ func TestPipeReadAheadBuffer(t *testing.T) { rfile.Close() t.Fatalf("newPipeOperations got error %v, want nil", err) } - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, @@ -380,7 +382,7 @@ func TestPipeReadsAccumulate(t *testing.T) { } // Don't forget to remove the fd from the fd notifier. Otherwise other tests will // likely be borked, because it's global :( - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, @@ -448,7 +450,7 @@ func TestPipeWritesAccumulate(t *testing.T) { } // Don't forget to remove the fd from the fd notifier. Otherwise other tests // will likely be borked, because it's global :( - defer p.Release() + defer p.Release(ctx) inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ Type: fs.Pipe, diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index ca41520b4..72ea70fcf 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -142,17 +142,17 @@ func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOper } // DecRef destroys the File when it is no longer referenced. -func (f *File) DecRef() { - f.DecRefWithDestructor(func() { +func (f *File) DecRef(ctx context.Context) { + f.DecRefWithDestructor(ctx, func(context.Context) { // Drop BSD style locks. lockRng := lock.LockRange{Start: 0, End: lock.LockEOF} f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng) // Release resources held by the FileOperations. - f.FileOperations.Release() + f.FileOperations.Release(ctx) // Release a reference on the Dirent. - f.Dirent.DecRef() + f.Dirent.DecRef(ctx) // Only unregister if we are currently registered. There is nothing // to register if f.async is nil (this happens when async mode is @@ -460,7 +460,7 @@ func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) { func (f *File) MappedName(ctx context.Context) string { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } name, _ := f.Dirent.FullName(root) return name diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index f5537411e..305c0f840 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -67,7 +67,7 @@ type SpliceOpts struct { // - File.Flags(): This value may change during the operation. type FileOperations interface { // Release release resources held by FileOperations. - Release() + Release(ctx context.Context) // Waitable defines how this File can be waited on for read and // write readiness. diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index dcc1df38f..9dc58d5ff 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -54,7 +54,7 @@ func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, err // Drop the extra reference on the Dirent. Now there's only one reference // on the dirent, either owned by f (if non-nil), or the Dirent is about // to be destroyed (if GetFile failed). - dirent.DecRef() + dirent.DecRef(ctx) return f, err } @@ -89,12 +89,12 @@ type overlayFileOperations struct { } // Release implements FileOperations.Release. -func (f *overlayFileOperations) Release() { +func (f *overlayFileOperations) Release(ctx context.Context) { if f.upper != nil { - f.upper.DecRef() + f.upper.DecRef(ctx) } if f.lower != nil { - f.lower.DecRef() + f.lower.DecRef(ctx) } } @@ -164,7 +164,7 @@ func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence See func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &DirCtx{ @@ -497,7 +497,7 @@ func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) { if err != nil { return nil, err } - defer dir.DecRef() + defer dir.DecRef(ctx) // Use a stub serializer to read the entries into memory. stubSerializer := &CollectEntriesSerializer{} @@ -521,10 +521,10 @@ type overlayMappingIdentity struct { } // DecRef implements AtomicRefCount.DecRef. -func (omi *overlayMappingIdentity) DecRef() { - omi.AtomicRefCount.DecRefWithDestructor(func() { - omi.overlayFile.DecRef() - omi.id.DecRef() +func (omi *overlayMappingIdentity) DecRef(ctx context.Context) { + omi.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { + omi.overlayFile.DecRef(ctx) + omi.id.DecRef(ctx) }) } @@ -544,7 +544,7 @@ func (omi *overlayMappingIdentity) InodeID() uint64 { func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string { root := RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } name, _ := omi.overlayFile.Dirent.FullName(root) return name diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 08695391c..dc9efa5df 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -31,7 +31,7 @@ import ( type FileNoopRelease struct{} // Release is a no-op. -func (FileNoopRelease) Release() {} +func (FileNoopRelease) Release(context.Context) {} // SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor // is not nil and the seek was on a directory, the cursor will be updated. @@ -296,7 +296,7 @@ func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, d *fs.Diren func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index b2fcab127..c0bc63a32 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -114,7 +114,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } // Release implements fs.FileOpeations.Release. -func (f *fileOperations) Release() { +func (f *fileOperations) Release(context.Context) { f.handles.DecRef() } @@ -122,7 +122,7 @@ func (f *fileOperations) Release() { func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go index 2df2fe889..326fed954 100644 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ b/pkg/sentry/fs/gofer/gofer_test.go @@ -232,7 +232,7 @@ func TestRevalidation(t *testing.T) { // We must release the dirent, of the test will fail // with a reference leak. This is tracked by p9test. - defer dirent.DecRef() + defer dirent.DecRef(ctx) // Walk again. Depending on the cache policy, we may // get a new dirent. @@ -246,7 +246,7 @@ func TestRevalidation(t *testing.T) { if !test.preModificationWantReload && dirent != newDirent { t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent) } - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. // Modify the underlying mocked file's modification // time for the next walk that occurs. @@ -287,7 +287,7 @@ func TestRevalidation(t *testing.T) { if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds { t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds) } - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. // Remove the file from the remote fs, subsequent walks // should now fail to find anything. @@ -303,7 +303,7 @@ func TestRevalidation(t *testing.T) { t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err) } if err == nil { - newDirent.DecRef() // See above. + newDirent.DecRef(ctx) // See above. } }) } diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index fc14249be..f324dbf26 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -47,7 +47,8 @@ type handles struct { // DecRef drops a reference on handles. func (h *handles) DecRef() { - h.DecRefWithDestructor(func() { + ctx := context.Background() + h.DecRefWithDestructor(ctx, func(context.Context) { if h.Host != nil { if h.isHostBorrowed { h.Host.Release() @@ -57,7 +58,7 @@ func (h *handles) DecRef() { } } } - if err := h.File.close(context.Background()); err != nil { + if err := h.File.close(ctx); err != nil { log.Warningf("error closing p9 file: %v", err) } }) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 51d7368a1..3a225fd39 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -441,8 +441,9 @@ func (i *inodeOperations) Release(ctx context.Context) { // asynchronously. // // We use AsyncWithContext to avoid needing to allocate an extra - // anonymous function on the heap. - fs.AsyncWithContext(ctx, i.fileState.Release) + // anonymous function on the heap. We must use background context + // because the async work cannot happen on the task context. + fs.AsyncWithContext(context.Background(), i.fileState.Release) } // Mappable implements fs.InodeOperations.Mappable. diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index cf9800100..3c66dc3c2 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -168,7 +168,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string // Construct the positive Dirent. d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name) - defer d.DecRef() + defer d.DecRef(ctx) // Construct the new file, caching the handles if allowed. h := handles{ @@ -371,7 +371,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string // Find out if file being deleted is a socket or pipe that needs to be // removed from endpoint map. if d, err := i.Lookup(ctx, dir, name); err == nil { - defer d.DecRef() + defer d.DecRef(ctx) if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) { switch iops := d.Inode.InodeOperations.(type) { @@ -392,7 +392,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string return err } if key != nil { - i.session().overrides.remove(*key) + i.session().overrides.remove(ctx, *key) } i.touchModificationAndStatusChangeTime(ctx, dir) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index b5efc86f2..7cf3522ff 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -89,10 +89,10 @@ func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *f // remove deletes the key from the maps. // // Precondition: maps must have been locked with 'lock'. -func (e *overrideMaps) remove(key device.MultiDeviceKey) { +func (e *overrideMaps) remove(ctx context.Context, key device.MultiDeviceKey) { endpoint := e.keyMap[key] delete(e.keyMap, key) - endpoint.dirent.DecRef() + endpoint.dirent.DecRef(ctx) } // lock blocks other addition and removal operations from happening while @@ -197,7 +197,7 @@ type session struct { } // Destroy tears down the session. -func (s *session) Destroy() { +func (s *session) Destroy(ctx context.Context) { s.client.Close() } @@ -329,7 +329,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF s.client, err = p9.NewClient(conn, s.msize, s.version) if err != nil { // Drop our reference on the session, it needs to be torn down. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -340,7 +340,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF ctx.UninterruptibleSleepFinish(false) if err != nil { // Same as above. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -348,7 +348,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF if err != nil { s.attach.close(ctx) // Same as above, but after we execute the Close request. - s.DecRef() + s.DecRef(ctx) return nil, err } @@ -393,13 +393,13 @@ func (s *session) fillKeyMap(ctx context.Context) error { // fillPathMap populates paths for overrides from dirents in direntMap // before save. -func (s *session) fillPathMap() error { +func (s *session) fillPathMap(ctx context.Context) error { unlock := s.overrides.lock() defer unlock() for _, endpoint := range s.overrides.keyMap { mountRoot := endpoint.dirent.MountRoot() - defer mountRoot.DecRef() + defer mountRoot.DecRef(ctx) dirPath, _ := endpoint.dirent.FullName(mountRoot) if dirPath == "" { return fmt.Errorf("error getting path from dirent") diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go index 2d398b753..48b423dd8 100644 --- a/pkg/sentry/fs/gofer/session_state.go +++ b/pkg/sentry/fs/gofer/session_state.go @@ -26,7 +26,8 @@ import ( // beforeSave is invoked by stateify. func (s *session) beforeSave() { if s.overrides != nil { - if err := s.fillPathMap(); err != nil { + ctx := &dummyClockContext{context.Background()} + if err := s.fillPathMap(ctx); err != nil { panic("failed to save paths to override map before saving" + err.Error()) } } diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 40f2c1cad..8a1c69ac2 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -134,14 +134,14 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect // We don't need the receiver. c.CloseRecv() - c.Release() + c.Release(ctx) return c, nil } // Release implements transport.BoundEndpoint.Release. -func (e *endpoint) Release() { - e.inode.DecRef() +func (e *endpoint) Release(ctx context.Context) { + e.inode.DecRef(ctx) } // Passcred implements transport.BoundEndpoint.Passcred. diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index 39299b7e4..0d8d36afa 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -57,7 +57,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (c *scmRights) Release() { +func (c *scmRights) Release(ctx context.Context) { for _, fd := range c.fds { syscall.Close(fd) } diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index 3e48b8b2c..86d1a87f0 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -110,7 +110,7 @@ func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID) dirent := fs.NewDirent(ctx, inode, name) - defer dirent.DecRef() + defer dirent.DecRef(ctx) if isTTY { return newTTYFile(ctx, dirent, flags, iops), nil @@ -169,7 +169,7 @@ func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go index c507f57eb..41a23b5da 100644 --- a/pkg/sentry/fs/host/inode_test.go +++ b/pkg/sentry/fs/host/inode_test.go @@ -36,7 +36,7 @@ func TestCloseFD(t *testing.T) { if err != nil { t.Fatalf("Failed to create File: %v", err) } - file.DecRef() + file.DecRef(ctx) s := make([]byte, 10) if c, err := syscall.Read(p[0], s); c != 0 || err != nil { diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index cfb089e43..a2f3d5918 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -194,7 +194,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) } // Send implements transport.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -271,7 +271,7 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -318,7 +318,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek } // close releases all resources related to the endpoint. -func (c *ConnectedEndpoint) close() { +func (c *ConnectedEndpoint) close(context.Context) { fdnotifier.RemoveFD(int32(c.file.FD())) c.file.Close() c.file = nil @@ -374,8 +374,8 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { } // Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release. -func (c *ConnectedEndpoint) Release() { - c.ref.DecRefWithDestructor(c.close) +func (c *ConnectedEndpoint) Release(ctx context.Context) { + c.ref.DecRefWithDestructor(ctx, c.close) } // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index affdbcacb..9d58ea448 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -67,11 +67,12 @@ func TestSocketIsBlocking(t *testing.T) { if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { t.Fatalf("Expected socket %v to be blocking", pair[1]) } - sock, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sock, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) failed => %v", pair[0], err) } - defer sock.DecRef() + defer sock.DecRef(ctx) // Test that the socket now is non-blocking. if fl, err = getFl(pair[0]); err != nil { t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) @@ -93,11 +94,12 @@ func TestSocketWritev(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - socket, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + socket, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer socket.DecRef() + defer socket.DecRef(ctx) buf := []byte("hello world\n") n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) if err != nil { @@ -115,11 +117,12 @@ func TestSocketWritevLen0(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - socket, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + socket, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer socket.DecRef() + defer socket.DecRef(ctx) n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) if err != nil { t.Fatalf("socket writev failed: %v", err) @@ -136,11 +139,12 @@ func TestSocketSendMsgLen0(t *testing.T) { if err != nil { t.Fatalf("host socket creation failed: %v", err) } - sfile, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sfile, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer sfile.DecRef() + defer sfile.DecRef(ctx) s := sfile.FileOperations.(socket.Socket) n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) @@ -158,18 +162,19 @@ func TestListen(t *testing.T) { if err != nil { t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) } - sfile1, err := newSocket(contexttest.Context(t), pair[0], false) + ctx := contexttest.Context(t) + sfile1, err := newSocket(ctx, pair[0], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[0], err) } - defer sfile1.DecRef() + defer sfile1.DecRef(ctx) socket1 := sfile1.FileOperations.(socket.Socket) - sfile2, err := newSocket(contexttest.Context(t), pair[1], false) + sfile2, err := newSocket(ctx, pair[1], false) if err != nil { t.Fatalf("newSocket(%v) => %v", pair[1], err) } - defer sfile2.DecRef() + defer sfile2.DecRef(ctx) socket2 := sfile2.FileOperations.(socket.Socket) // Socketpairs can not be listened to. @@ -185,11 +190,11 @@ func TestListen(t *testing.T) { if err != nil { t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) } - sfile3, err := newSocket(contexttest.Context(t), sock, false) + sfile3, err := newSocket(ctx, sock, false) if err != nil { t.Fatalf("newSocket(%v) => %v", sock, err) } - defer sfile3.DecRef() + defer sfile3.DecRef(ctx) socket3 := sfile3.FileOperations.(socket.Socket) // This socket is not bound so we can't listen on it. @@ -237,9 +242,10 @@ func TestRelease(t *testing.T) { } c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} want := &ConnectedEndpoint{queue: c.queue} - want.ref.DecRef() + ctx := contexttest.Context(t) + want.ref.DecRef(ctx) fdnotifier.AddFD(int32(c.file.FD()), nil) - c.Release() + c.Release(ctx) if !reflect.DeepEqual(c, want) { t.Errorf("got = %#v, want = %#v", c, want) } diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 82a02fcb2..b5229098c 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -113,12 +113,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme } // Release implements fs.FileOperations.Release. -func (t *TTYFileOperations) Release() { +func (t *TTYFileOperations) Release(ctx context.Context) { t.mu.Lock() t.fgProcessGroup = nil t.mu.Unlock() - t.fileOperations.Release() + t.fileOperations.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go index ce397a5e3..c143f4ce2 100644 --- a/pkg/sentry/fs/host/wait_test.go +++ b/pkg/sentry/fs/host/wait_test.go @@ -39,7 +39,7 @@ func TestWait(t *testing.T) { t.Fatalf("NewFile failed: %v", err) } - defer file.DecRef() + defer file.DecRef(ctx) r := file.Readiness(waiter.EventIn) if r != 0 { diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index a34fbc946..b79cd9877 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -96,13 +96,12 @@ func NewInode(ctx context.Context, iops InodeOperations, msrc *MountSource, satt } // DecRef drops a reference on the Inode. -func (i *Inode) DecRef() { - i.DecRefWithDestructor(i.destroy) +func (i *Inode) DecRef(ctx context.Context) { + i.DecRefWithDestructor(ctx, i.destroy) } // destroy releases the Inode and releases the msrc reference taken. -func (i *Inode) destroy() { - ctx := context.Background() +func (i *Inode) destroy(ctx context.Context) { if err := i.WriteOut(ctx); err != nil { // FIXME(b/65209558): Mark as warning again once noatime is // properly supported. @@ -122,12 +121,12 @@ func (i *Inode) destroy() { i.Watches.targetDestroyed() if i.overlay != nil { - i.overlay.release() + i.overlay.release(ctx) } else { i.InodeOperations.Release(ctx) } - i.MountSource.DecRef() + i.MountSource.DecRef(ctx) } // Mappable calls i.InodeOperations.Mappable. diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index efd3c962b..9911a00c2 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -17,6 +17,7 @@ package fs import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -136,11 +137,11 @@ func (w *Watches) Notify(name string, events, cookie uint32) { } // Unpin unpins dirent from all watches in this set. -func (w *Watches) Unpin(d *Dirent) { +func (w *Watches) Unpin(ctx context.Context, d *Dirent) { w.mu.RLock() defer w.mu.RUnlock() for _, watch := range w.ws { - watch.Unpin(d) + watch.Unpin(ctx, d) } } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 537c8d257..dc2e353d9 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -85,7 +85,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name upperInode = child.Inode upperInode.IncRef() } - child.DecRef() + child.DecRef(ctx) } // Are we done? @@ -108,7 +108,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name entry, err := newOverlayEntry(ctx, upperInode, nil, false) if err != nil { // Don't leak resources. - upperInode.DecRef() + upperInode.DecRef(ctx) parent.copyMu.RUnlock() return nil, false, err } @@ -129,7 +129,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name if err != nil && err != syserror.ENOENT { // Don't leak resources. if upperInode != nil { - upperInode.DecRef() + upperInode.DecRef(ctx) } parent.copyMu.RUnlock() return nil, false, err @@ -152,7 +152,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } } } - child.DecRef() + child.DecRef(ctx) } } @@ -183,7 +183,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // unnecessary because we don't need to copy-up and we will always // operate (e.g. read/write) on the upper Inode. if !IsDir(upperInode.StableAttr) { - lowerInode.DecRef() + lowerInode.DecRef(ctx) lowerInode = nil } } @@ -194,10 +194,10 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // Well, not quite, we failed at the last moment, how depressing. // Be sure not to leak resources. if upperInode != nil { - upperInode.DecRef() + upperInode.DecRef(ctx) } if lowerInode != nil { - lowerInode.DecRef() + lowerInode.DecRef(ctx) } parent.copyMu.RUnlock() return nil, false, err @@ -248,7 +248,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st // user) will clobber the real path for the underlying Inode. upperFile.Dirent.Inode.IncRef() upperDirent := NewTransientDirent(upperFile.Dirent.Inode) - upperFile.Dirent.DecRef() + upperFile.Dirent.DecRef(ctx) upperFile.Dirent = upperDirent // Create the overlay inode and dirent. We need this to construct the @@ -259,7 +259,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st // The overlay file created below with NewFile will take a reference on // the overlayDirent, and it should be the only thing holding a // reference at the time of creation, so we must drop this reference. - defer overlayDirent.DecRef() + defer overlayDirent.DecRef(ctx) // Create a new overlay file that wraps the upper file. flags.Pread = upperFile.Flags().Pread @@ -399,7 +399,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) { children, err := readdirOne(ctx, replaced) if err != nil { - replaced.DecRef() + replaced.DecRef(ctx) return err } @@ -407,12 +407,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena // included among the returned children, so we don't // need to bother checking for them. if len(children) > 0 { - replaced.DecRef() + replaced.DecRef(ctx) return syserror.ENOTEMPTY } } - replaced.DecRef() + replaced.DecRef(ctx) } } @@ -455,12 +455,12 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri // Grab the inode and drop the dirent, we don't need it. inode := d.Inode inode.IncRef() - d.DecRef() + d.DecRef(ctx) // Create a new overlay entry and dirent for the socket. entry, err := newOverlayEntry(ctx, inode, nil, false) if err != nil { - inode.DecRef() + inode.DecRef(ctx) return nil, err } // Use the parent's MountSource, since that corresponds to the overlay, @@ -672,7 +672,7 @@ func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) { // ground and claim that jumping around the filesystem like this // is not supported. name, _ := dirent.FullName(nil) - dirent.DecRef() + dirent.DecRef(ctx) // Claim that the path is not accessible. err = syserror.EACCES diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go index 389c219d6..aa9851b26 100644 --- a/pkg/sentry/fs/inode_overlay_test.go +++ b/pkg/sentry/fs/inode_overlay_test.go @@ -316,7 +316,7 @@ func TestCacheFlush(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) ctx = &rootContext{ Context: ctx, @@ -345,7 +345,7 @@ func TestCacheFlush(t *testing.T) { } // Drop the file reference. - file.DecRef() + file.DecRef(ctx) // Dirent should have 2 refs left. if got, want := dirent.ReadRefs(), 2; int(got) != want { @@ -361,7 +361,7 @@ func TestCacheFlush(t *testing.T) { } // Drop our ref. - dirent.DecRef() + dirent.DecRef(ctx) // We should be back to zero refs. if got, want := dirent.ReadRefs(), 0; int(got) != want { @@ -398,7 +398,7 @@ func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags if err != nil { return nil, err } - defer file.DecRef() + defer file.DecRef(ctx) // Wrap the file's FileOperations in a dirFile. fops := &dirFile{ FileOperations: file.FileOperations, diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index e3a715c1f..c5c07d564 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -80,7 +80,7 @@ func NewInotify(ctx context.Context) *Inotify { // Release implements FileOperations.Release. Release removes all watches and // frees all resources for an inotify instance. -func (i *Inotify) Release() { +func (i *Inotify) Release(ctx context.Context) { // We need to hold i.mu to avoid a race with concurrent calls to // Inotify.targetDestroyed from Watches. There's no risk of Watches // accessing this Inotify after the destructor ends, because we remove all @@ -93,7 +93,7 @@ func (i *Inotify) Release() { // the owner's destructor. w.target.Watches.Remove(w.ID()) // Don't leak any references to the target, held by pins in the watch. - w.destroy() + w.destroy(ctx) } } @@ -321,7 +321,7 @@ func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 { // // RmWatch looks up an inotify watch for the given 'wd' and configures the // target dirent to stop sending events to this inotify instance. -func (i *Inotify) RmWatch(wd int32) error { +func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { i.mu.Lock() // Find the watch we were asked to removed. @@ -346,7 +346,7 @@ func (i *Inotify) RmWatch(wd int32) error { i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0)) // Remove all pins. - watch.destroy() + watch.destroy(ctx) return nil } diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 900cba3ca..605423d22 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" ) @@ -105,12 +106,12 @@ func (w *Watch) Pin(d *Dirent) { // Unpin drops any extra refs held on dirent due to a previous Pin // call. Calling Unpin multiple times for the same dirent, or on a dirent // without a corresponding Pin call is a no-op. -func (w *Watch) Unpin(d *Dirent) { +func (w *Watch) Unpin(ctx context.Context, d *Dirent) { w.mu.Lock() defer w.mu.Unlock() if w.pins[d] { delete(w.pins, d) - d.DecRef() + d.DecRef(ctx) } } @@ -125,11 +126,11 @@ func (w *Watch) TargetDestroyed() { // this watch. Destroy does not cause any new events to be generated. The caller // is responsible for ensuring there are no outstanding references to this // watch. -func (w *Watch) destroy() { +func (w *Watch) destroy(ctx context.Context) { w.mu.Lock() defer w.mu.Unlock() for d := range w.pins { - d.DecRef() + d.DecRef(ctx) } w.pins = nil } diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go index 37bae6810..ee69b10e8 100644 --- a/pkg/sentry/fs/mount.go +++ b/pkg/sentry/fs/mount.go @@ -51,7 +51,7 @@ type MountSourceOperations interface { DirentOperations // Destroy destroys the MountSource. - Destroy() + Destroy(ctx context.Context) // Below are MountSourceOperations that do not conform to Linux. @@ -165,16 +165,16 @@ func (msrc *MountSource) DecDirentRefs() { } } -func (msrc *MountSource) destroy() { +func (msrc *MountSource) destroy(ctx context.Context) { if c := msrc.DirentRefs(); c != 0 { panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c)) } - msrc.MountSourceOperations.Destroy() + msrc.MountSourceOperations.Destroy(ctx) } // DecRef drops a reference on the MountSource. -func (msrc *MountSource) DecRef() { - msrc.DecRefWithDestructor(msrc.destroy) +func (msrc *MountSource) DecRef(ctx context.Context) { + msrc.DecRefWithDestructor(ctx, msrc.destroy) } // FlushDirentRefs drops all references held by the MountSource on Dirents. @@ -264,7 +264,7 @@ func (*SimpleMountSourceOperations) ResetInodeMappings() {} func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {} // Destroy implements MountSourceOperations.Destroy. -func (*SimpleMountSourceOperations) Destroy() {} +func (*SimpleMountSourceOperations) Destroy(context.Context) {} // Info defines attributes of a filesystem. type Info struct { diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go index 78e35b1e6..7badc75d6 100644 --- a/pkg/sentry/fs/mount_overlay.go +++ b/pkg/sentry/fs/mount_overlay.go @@ -115,9 +115,9 @@ func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path strin } // Destroy drops references on the upper and lower MountSource. -func (o *overlayMountSourceOperations) Destroy() { - o.upper.DecRef() - o.lower.DecRef() +func (o *overlayMountSourceOperations) Destroy(ctx context.Context) { + o.upper.DecRef(ctx) + o.lower.DecRef(ctx) } // type overlayFilesystem is the filesystem for overlay mounts. diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go index a3d10770b..6c296f5d0 100644 --- a/pkg/sentry/fs/mount_test.go +++ b/pkg/sentry/fs/mount_test.go @@ -18,6 +18,7 @@ import ( "fmt" "testing" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/contexttest" ) @@ -32,13 +33,13 @@ func cacheReallyContains(cache *DirentCache, d *Dirent) bool { return false } -func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { +func mountPathsAre(ctx context.Context, root *Dirent, got []*Mount, want ...string) error { gotPaths := make(map[string]struct{}, len(got)) gotStr := make([]string, len(got)) for i, g := range got { if groot := g.Root(); groot != nil { name, _ := groot.FullName(root) - groot.DecRef() + groot.DecRef(ctx) gotStr[i] = name gotPaths[name] = struct{}{} } @@ -69,7 +70,7 @@ func TestMountSourceOnlyCachedOnce(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Get a child of the root which we will mount over. Note that the // MockInodeOperations causes Walk to always succeed. @@ -125,7 +126,7 @@ func TestAllMountsUnder(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Add mounts at the following paths: paths := []string{ @@ -150,14 +151,14 @@ func TestAllMountsUnder(t *testing.T) { if err := mm.Mount(ctx, d, submountInode); err != nil { t.Fatalf("could not mount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) } // mm root should contain all submounts (and does not include the root mount). rootMnt := mm.FindMount(rootDirent) submounts := mm.AllMountsUnder(rootMnt) allPaths := append(paths, "/") - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { t.Error(err) } @@ -181,9 +182,9 @@ func TestAllMountsUnder(t *testing.T) { if err != nil { t.Fatalf("could not find path %q in mount manager: %v", "/foo", err) } - defer d.DecRef() + defer d.DecRef(ctx) submounts = mm.AllMountsUnder(mm.FindMount(d)) - if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { t.Error(err) } @@ -193,9 +194,9 @@ func TestAllMountsUnder(t *testing.T) { if err != nil { t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err) } - defer waldo.DecRef() + defer waldo.DecRef(ctx) submounts = mm.AllMountsUnder(mm.FindMount(waldo)) - if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, "/waldo"); err != nil { t.Error(err) } } @@ -212,7 +213,7 @@ func TestUnmount(t *testing.T) { t.Fatalf("NewMountNamespace failed: %v", err) } rootDirent := mm.Root() - defer rootDirent.DecRef() + defer rootDirent.DecRef(ctx) // Add mounts at the following paths: paths := []string{ @@ -240,7 +241,7 @@ func TestUnmount(t *testing.T) { if err := mm.Mount(ctx, d, submountInode); err != nil { t.Fatalf("could not mount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) } allPaths := make([]string, len(paths)+1) @@ -259,13 +260,13 @@ func TestUnmount(t *testing.T) { if err := mm.Unmount(ctx, d, false); err != nil { t.Fatalf("could not unmount at %q: %v", p, err) } - d.DecRef() + d.DecRef(ctx) // Remove the path that has been unmounted and the check that the remaining // mounts are still there. allPaths = allPaths[:len(allPaths)-1] submounts := mm.AllMountsUnder(rootMnt) - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { + if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil { t.Error(err) } } diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 3f2bd0e87..d741c4339 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -234,7 +234,7 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() { // After destroy is called, the MountNamespace may continue to be referenced (for // example via /proc/mounts), but should free all resources and shouldn't have // Find* methods called. -func (mns *MountNamespace) destroy() { +func (mns *MountNamespace) destroy(ctx context.Context) { mns.mu.Lock() defer mns.mu.Unlock() @@ -247,13 +247,13 @@ func (mns *MountNamespace) destroy() { for _, mp := range mns.mounts { // Drop the mount reference on all mounted dirents. for ; mp != nil; mp = mp.previous { - mp.root.DecRef() + mp.root.DecRef(ctx) } } mns.mounts = nil // Drop reference on the root. - mns.root.DecRef() + mns.root.DecRef(ctx) // Ensure that root cannot be accessed via this MountNamespace any // more. @@ -265,8 +265,8 @@ func (mns *MountNamespace) destroy() { } // DecRef implements RefCounter.DecRef with destructor mns.destroy. -func (mns *MountNamespace) DecRef() { - mns.DecRefWithDestructor(mns.destroy) +func (mns *MountNamespace) DecRef(ctx context.Context) { + mns.DecRefWithDestructor(ctx, mns.destroy) } // withMountLocked prevents further walks to `node`, because `node` is about to @@ -312,7 +312,7 @@ func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode if err != nil { return err } - defer replacement.DecRef() + defer replacement.DecRef(ctx) // Set the mount's root dirent and id. parentMnt := mns.findMountLocked(mountPoint) @@ -394,7 +394,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev)) } // Drop mount reference taken at the end of MountNamespace.Mount. - prev.root.DecRef() + prev.root.DecRef(ctx) } else { mns.mounts[prev.root] = prev } @@ -496,11 +496,11 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path // non-directory root is hopeless. if current != root { if !IsDir(current.Inode.StableAttr) { - current.DecRef() // Drop reference from above. + current.DecRef(ctx) // Drop reference from above. return nil, syserror.ENOTDIR } if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil { - current.DecRef() // Drop reference from above. + current.DecRef(ctx) // Drop reference from above. return nil, err } } @@ -511,12 +511,12 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path // Allow failed walks to cache the dirent, because no // children will acquire a reference at the end. current.maybeExtendReference() - current.DecRef() + current.DecRef(ctx) return nil, err } // Drop old reference. - current.DecRef() + current.DecRef(ctx) if remainder != "" { // Ensure it's resolved, unless it's the last level. @@ -570,11 +570,11 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema case nil: // Make sure we didn't exhaust the traversal budget. if *remainingTraversals == 0 { - target.DecRef() + target.DecRef(ctx) return nil, syscall.ELOOP } - node.DecRef() // Drop the original reference. + node.DecRef(ctx) // Drop the original reference. return target, nil case syscall.ENOLINK: @@ -582,7 +582,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema return node, nil case ErrResolveViaReadlink: - defer node.DecRef() // See above. + defer node.DecRef(ctx) // See above. // First, check if we should traverse. if *remainingTraversals == 0 { @@ -608,7 +608,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema return d, err default: - node.DecRef() // Drop for err; see above. + node.DecRef(ctx) // Drop for err; see above. // Propagate the error. return nil, err diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go index a69b41468..975d6cbc9 100644 --- a/pkg/sentry/fs/mounts_test.go +++ b/pkg/sentry/fs/mounts_test.go @@ -51,7 +51,7 @@ func TestFindLink(t *testing.T) { } root := mm.Root() - defer root.DecRef() + defer root.DecRef(ctx) foo, err := root.Walk(ctx, root, "foo") if err != nil { t.Fatalf("Error walking to foo: %v", err) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index a8ae7d81d..35013a21b 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -107,7 +107,7 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount msrc := newOverlayMountSource(ctx, upper.MountSource, lower.MountSource, flags) overlay, err := newOverlayEntry(ctx, upper, lower, true) if err != nil { - msrc.DecRef() + msrc.DecRef(ctx) return nil, err } @@ -130,7 +130,7 @@ func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, msrc := newOverlayMountSource(ctx, upperMS, lower.MountSource, flags) overlay, err := newOverlayEntry(ctx, nil, lower, true) if err != nil { - msrc.DecRef() + msrc.DecRef(ctx) return nil, err } return newOverlayInode(ctx, overlay, msrc), nil @@ -230,16 +230,16 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist }, nil } -func (o *overlayEntry) release() { +func (o *overlayEntry) release(ctx context.Context) { // We drop a reference on upper and lower file system Inodes // rather than releasing them, because in-memory filesystems // may hold an extra reference to these Inodes so that they // stay in memory. if o.upper != nil { - o.upper.DecRef() + o.upper.DecRef(ctx) } if o.lower != nil { - o.lower.DecRef() + o.lower.DecRef(ctx) } } diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index 35972e23c..45523adf8 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -56,11 +56,11 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF // readDescriptors reads fds in the task starting at offset, and calls the // toDentAttr callback for each to get a DentAttr, which it then emits. This is // a helper for implementing fs.InodeOperations.Readdir. -func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) { +func readDescriptors(ctx context.Context, t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) { var fds []int32 t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.GetFDs() + fds = fdTable.GetFDs(ctx) } }) @@ -116,7 +116,7 @@ func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } n, _ := f.file.Dirent.FullName(root) return n, nil @@ -135,13 +135,7 @@ func (f *fd) Truncate(context.Context, *fs.Inode, int64) error { func (f *fd) Release(ctx context.Context) { f.Symlink.Release(ctx) - f.file.DecRef() -} - -// Close releases the reference on the file. -func (f *fd) Close() error { - f.file.DecRef() - return nil + f.file.DecRef(ctx) } // fdDir is an InodeOperations for /proc/TID/fd. @@ -227,7 +221,7 @@ func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySer if f.isInfoFile { typ = fs.Symlink } - return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr { + return readDescriptors(ctx, f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr { return fs.GenericDentAttr(typ, device.ProcDevice) }) } @@ -261,7 +255,7 @@ func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs // locks, and other data. For now we only have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags() - file.DecRef() + file.DecRef(ctx) contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags)) return newStaticProcInode(ctx, dir.MountSource, contents) }) diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index 1fc9c703c..6a63c47b3 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -47,7 +47,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { // The task has been destroyed. Nothing to show here. return } - defer rootDir.DecRef() + defer rootDir.DecRef(t) mnt := t.MountNamespace().FindMount(rootDir) if mnt == nil { @@ -64,7 +64,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { continue // No longer valid. } mountPath, desc := mroot.FullName(rootDir) - mroot.DecRef() + mroot.DecRef(t) if !desc { // MountSources that are not descendants of the chroot jail are ignored. continue @@ -97,7 +97,7 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se if mroot == nil { return // No longer valid. } - defer mroot.DecRef() + defer mroot.DecRef(ctx) // Format: // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue @@ -216,7 +216,7 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan if root == nil { return // No longer valid. } - defer root.DecRef() + defer root.DecRef(ctx) flags := root.Inode.MountSource.Flags opts := "rw" diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index bd18177d4..83a43aa26 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -419,7 +419,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } sfile := s.(*fs.File) if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { - s.DecRef() + s.DecRef(ctx) // Not a unix socket. continue } @@ -479,7 +479,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ @@ -574,7 +574,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { - s.DecRef() + s.DecRef(ctx) // Not tcp4 sockets. continue } @@ -664,7 +664,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ @@ -752,7 +752,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { - s.DecRef() + s.DecRef(ctx) // Not udp4 socket. continue } @@ -822,7 +822,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "\n") - s.DecRef() + s.DecRef(ctx) } data := []seqfile.SeqData{ diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index c659224a7..77e0e1d26 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -213,7 +213,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // Add dot and dotdot. root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dot, dotdot := file.Dirent.GetDotAttrs(root) names = append(names, ".", "..") diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 4bbe90198..9cf7f2a62 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -185,7 +185,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry // Serialize "." and "..". root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dot, dotdot := file.Dirent.GetDotAttrs(root) if err := dirCtx.DirEmit(".", dot); err != nil { @@ -295,7 +295,7 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { if err != nil { return "", err } - defer exec.DecRef() + defer exec.DecRef(ctx) return exec.PathnameWithDeleted(ctx), nil } diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index bfa304552..f4fcddecb 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -219,7 +219,7 @@ func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error { } // Remove our reference on the inode. - inode.DecRef() + inode.DecRef(ctx) return nil } @@ -250,7 +250,7 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err } // Remove our reference on the inode. - inode.DecRef() + inode.DecRef(ctx) return nil } @@ -326,7 +326,7 @@ func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.F // Create the Dirent and corresponding file. created := fs.NewDirent(ctx, inode, name) - defer created.DecRef() + defer created.DecRef(ctx) return created.Inode.GetFile(ctx, created, flags) } @@ -412,11 +412,11 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol } // Release implements fs.InodeOperation.Release. -func (d *Dir) Release(_ context.Context) { +func (d *Dir) Release(ctx context.Context) { // Drop references on all children. d.mu.Lock() for _, i := range d.children { - i.DecRef() + i.DecRef(ctx) } d.mu.Unlock() } @@ -456,7 +456,7 @@ func (dfo *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirC func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, @@ -473,13 +473,13 @@ func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) { // dropped when that dirent is destroyed. inode.IncRef() d := fs.NewTransientDirent(inode) - defer d.DecRef() + defer d.DecRef(ctx) file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true}) if err != nil { return false, err } - defer file.DecRef() + defer file.DecRef(ctx) ser := &fs.CollectEntriesSerializer{} if err := file.Readdir(ctx, ser); err != nil { @@ -530,7 +530,7 @@ func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, n if err != nil { return err } - inode.DecRef() + inode.DecRef(ctx) } // Be careful, we may have already grabbed this mutex above. diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go index a6ed8b2c5..3e0d1e07e 100644 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ b/pkg/sentry/fs/ramfs/tree_test.go @@ -67,7 +67,7 @@ func TestMakeDirectoryTree(t *testing.T) { continue } root := mm.Root() - defer mm.DecRef() + defer mm.DecRef(ctx) for _, p := range test.subdirs { maxTraversals := uint(0) diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 88c344089..f362ca9b6 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -55,7 +55,7 @@ type TimerOperations struct { func NewFile(ctx context.Context, c ktime.Clock) *fs.File { dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[timerfd]") // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) tops := &TimerOperations{} tops.timer = ktime.NewTimer(c, tops) // Timerfds reject writes, but the Write flag must be set in order to @@ -65,7 +65,7 @@ func NewFile(ctx context.Context, c ktime.Clock) *fs.File { } // Release implements fs.FileOperations.Release. -func (t *TimerOperations) Release() { +func (t *TimerOperations) Release(context.Context) { t.timer.Destroy() } diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go index aaba35502..d4d613ea9 100644 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ b/pkg/sentry/fs/tmpfs/file_test.go @@ -46,7 +46,7 @@ func newFile(ctx context.Context) *fs.File { func TestGrow(t *testing.T) { ctx := contexttest.Context(t) f := newFile(ctx) - defer f.DecRef() + defer f.DecRef(ctx) abuf := bytes.Repeat([]byte{'a'}, 68) n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0) diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 108654827..463f6189e 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -132,7 +132,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) { d.mu.Lock() defer d.mu.Unlock() - d.master.DecRef() + d.master.DecRef(ctx) if len(d.slaves) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) } @@ -263,7 +263,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e } // masterClose is called when the master end of t is closed. -func (d *dirInodeOperations) masterClose(t *Terminal) { +func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) { d.mu.Lock() defer d.mu.Unlock() @@ -277,7 +277,7 @@ func (d *dirInodeOperations) masterClose(t *Terminal) { panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d)) } - s.DecRef() + s.DecRef(ctx) delete(d.slaves, t.n) d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10)) } @@ -322,7 +322,7 @@ func (df *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCt func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) { root := fs.RootFromContext(ctx) if root != nil { - defer root.DecRef() + defer root.DecRef(ctx) } dirCtx := &fs.DirCtx{ Serializer: serializer, diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index 8fe05ebe5..2d4d44bf3 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -108,4 +108,4 @@ func (superOperations) ResetInodeMappings() {} func (superOperations) SaveInodeMapping(*fs.Inode, string) {} // Destroy implements MountSourceOperations.Destroy. -func (superOperations) Destroy() {} +func (superOperations) Destroy(context.Context) {} diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index fe07fa929..e00746017 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -75,7 +75,7 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn } // Release implements fs.InodeOperations.Release. -func (mi *masterInodeOperations) Release(ctx context.Context) { +func (mi *masterInodeOperations) Release(context.Context) { } // Truncate implements fs.InodeOperations.Truncate. @@ -120,9 +120,9 @@ type masterFileOperations struct { var _ fs.FileOperations = (*masterFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (mf *masterFileOperations) Release() { - mf.d.masterClose(mf.t) - mf.t.DecRef() +func (mf *masterFileOperations) Release(ctx context.Context) { + mf.d.masterClose(ctx, mf.t) + mf.t.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index 9871f6fc6..7c7292687 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -71,7 +71,7 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne // Release implements fs.InodeOperations.Release. func (si *slaveInodeOperations) Release(ctx context.Context) { - si.t.DecRef() + si.t.DecRef(ctx) } // Truncate implements fs.InodeOperations.Truncate. @@ -106,7 +106,7 @@ type slaveFileOperations struct { var _ fs.FileOperations = (*slaveFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (sf *slaveFileOperations) Release() { +func (sf *slaveFileOperations) Release(context.Context) { } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go index 397e96045..2f5a43b84 100644 --- a/pkg/sentry/fs/user/path.go +++ b/pkg/sentry/fs/user/path.go @@ -82,7 +82,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s // Caller has no root. Don't bother traversing anything. return "", syserror.ENOENT } - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range paths { if !path.IsAbs(p) { // Relative paths aren't safe, no one should be using them. @@ -100,7 +100,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s if err != nil { return "", err } - defer d.DecRef() + defer d.DecRef(ctx) // Check that it is a regular file. if !fs.IsRegular(d.Inode.StableAttr) { @@ -121,7 +121,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) { root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range paths { if !path.IsAbs(p) { // Relative paths aren't safe, no one should be using them. @@ -148,7 +148,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam if err != nil { return "", err } - dentry.DecRef() + dentry.DecRef(ctx) return binPath, nil } diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go index f4d525523..936fd3932 100644 --- a/pkg/sentry/fs/user/user.go +++ b/pkg/sentry/fs/user/user.go @@ -62,7 +62,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K // doesn't exist we will return the default home directory. return defaultHome, nil } - defer dirent.DecRef() + defer dirent.DecRef(ctx) // Check read permissions on the file. if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil { @@ -81,7 +81,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K if err != nil { return "", err } - defer f.DecRef() + defer f.DecRef(ctx) r := &fileReader{ Ctx: ctx, @@ -105,7 +105,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth. const defaultHome = "/" root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) creds := auth.CredentialsFromContext(ctx) @@ -123,7 +123,7 @@ func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth. if err != nil { return defaultHome, nil } - defer fd.DecRef() + defer fd.DecRef(ctx) r := &fileReaderVFS2{ ctx: ctx, diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go index 7d8e9ac7c..12b786224 100644 --- a/pkg/sentry/fs/user/user_test.go +++ b/pkg/sentry/fs/user/user_test.go @@ -39,7 +39,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode if err != nil { return err } - defer etc.DecRef() + defer etc.DecRef(ctx) switch mode.FileType() { case 0: // Don't create anything. @@ -49,7 +49,7 @@ func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode if err != nil { return err } - defer passwd.DecRef() + defer passwd.DecRef(ctx) if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil { return err } @@ -110,9 +110,9 @@ func TestGetExecUserHome(t *testing.T) { if err != nil { t.Fatalf("NewMountNamespace failed: %v", err) } - defer mns.DecRef() + defer mns.DecRef(ctx) root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) ctx = fs.WithRoot(ctx, root) if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil { diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go index 8e7590721..7e61209ee 100644 --- a/pkg/sentry/fsbridge/bridge.go +++ b/pkg/sentry/fsbridge/bridge.go @@ -44,7 +44,7 @@ type File interface { IncRef() // DecRef decrements reference. - DecRef() + DecRef(ctx context.Context) } // Lookup provides a common interface to open files. diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go index 093ce1fb3..9785fd62a 100644 --- a/pkg/sentry/fsbridge/fs.go +++ b/pkg/sentry/fsbridge/fs.go @@ -49,7 +49,7 @@ func (f *fsFile) PathnameWithDeleted(ctx context.Context) string { // global there. return "" } - defer root.DecRef() + defer root.DecRef(ctx) name, _ := f.file.Dirent.FullName(root) return name @@ -87,8 +87,8 @@ func (f *fsFile) IncRef() { } // DecRef implements File. -func (f *fsFile) DecRef() { - f.file.DecRef() +func (f *fsFile) DecRef(ctx context.Context) { + f.file.DecRef(ctx) } // fsLookup implements Lookup interface using fs.File. @@ -124,7 +124,7 @@ func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptio if err != nil { return nil, err } - defer d.DecRef() + defer d.DecRef(ctx) if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) { return nil, syserror.ELOOP diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 89168220a..323506d33 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -43,7 +43,7 @@ func NewVFSFile(file *vfs.FileDescription) File { // PathnameWithDeleted implements File. func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string { root := vfs.RootFromContext(ctx) - defer root.DecRef() + defer root.DecRef(ctx) vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem() name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry()) @@ -86,8 +86,8 @@ func (f *VFSFile) IncRef() { } // DecRef implements File. -func (f *VFSFile) DecRef() { - f.file.DecRef() +func (f *VFSFile) DecRef(ctx context.Context) { + f.file.DecRef(ctx) } // FileDescription returns the FileDescription represented by f. It does not diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index e6fda2b4f..7169e91af 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -103,9 +103,9 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // rootInode is the root directory inode for the devpts mounts. diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 1081fff52..3bb397f71 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -60,7 +60,7 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf } fd.LockFD.Init(&mi.locks) if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - mi.DecRef() + mi.DecRef(ctx) return nil, err } return &fd.vfsfd, nil @@ -98,9 +98,9 @@ type masterFileDescription struct { var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (mfd *masterFileDescription) Release() { +func (mfd *masterFileDescription) Release(ctx context.Context) { mfd.inode.root.masterClose(mfd.t) - mfd.inode.DecRef() + mfd.inode.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index a91cae3ef..32e4e1908 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -56,7 +56,7 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs } fd.LockFD.Init(&si.locks) if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - si.DecRef() + si.DecRef(ctx) return nil, err } return &fd.vfsfd, nil @@ -103,8 +103,8 @@ type slaveFileDescription struct { var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil) // Release implements fs.FileOperations.Release. -func (sfd *slaveFileDescription) Release() { - sfd.inode.DecRef() +func (sfd *slaveFileDescription) Release(ctx context.Context) { + sfd.inode.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index d0e06cdc0..2ed5fa8a9 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -92,9 +92,9 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth } // Release must be called when a is no longer in use. -func (a *Accessor) Release() { - a.root.DecRef() - a.mntns.DecRef() +func (a *Accessor) Release(ctx context.Context) { + a.root.DecRef(ctx) + a.mntns.DecRef(ctx) } // accessorContext implements context.Context by extending an existing diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go index b6d52c015..747867cca 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go @@ -30,7 +30,7 @@ func TestDevtmpfs(t *testing.T) { creds := auth.CredentialsFromContext(ctx) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Register tmpfs just so that we can have a root filesystem that isn't @@ -48,9 +48,9 @@ func TestDevtmpfs(t *testing.T) { if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) devpop := vfs.PathOperation{ Root: root, Start: root, @@ -69,7 +69,7 @@ func TestDevtmpfs(t *testing.T) { if err != nil { t.Fatalf("failed to create devtmpfs.Accessor: %v", err) } - defer a.Release() + defer a.Release(ctx) // Create "userspace-initialized" files using a devtmpfs.Accessor. if err := a.UserspaceInit(ctx); err != nil { diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index d12d78b84..812171fa3 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -59,9 +59,9 @@ type EventFileDescription struct { var _ vfs.FileDescriptionImpl = (*EventFileDescription)(nil) // New creates a new event fd. -func New(vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) { +func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[eventfd]") - defer vd.DecRef() + defer vd.DecRef(ctx) efd := &EventFileDescription{ val: initVal, semMode: semMode, @@ -107,7 +107,7 @@ func (efd *EventFileDescription) HostFD() (int, error) { } // Release implements FileDescriptionImpl.Release() -func (efd *EventFileDescription) Release() { +func (efd *EventFileDescription) Release(context.Context) { efd.mu.Lock() defer efd.mu.Unlock() if efd.hostfd >= 0 { diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go index 20e3adffc..49916fa81 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd_test.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd_test.go @@ -36,16 +36,16 @@ func TestEventFD(t *testing.T) { for _, initVal := range initVals { ctx := contexttest.Context(t) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Make a new eventfd that is writable. - eventfd, err := New(vfsObj, initVal, false, linux.O_RDWR) + eventfd, err := New(ctx, vfsObj, initVal, false, linux.O_RDWR) if err != nil { t.Fatalf("New() failed: %v", err) } - defer eventfd.DecRef() + defer eventfd.DecRef(ctx) // Register a callback for a write event. w, ch := waiter.NewChannelEntry(nil) @@ -74,16 +74,16 @@ func TestEventFD(t *testing.T) { func TestEventFDStat(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } // Make a new eventfd that is writable. - eventfd, err := New(vfsObj, 0, false, linux.O_RDWR) + eventfd, err := New(ctx, vfsObj, 0, false, linux.O_RDWR) if err != nil { t.Fatalf("New() failed: %v", err) } - defer eventfd.DecRef() + defer eventfd.DecRef(ctx) statx, err := eventfd.Stat(ctx, vfs.StatOptions{ Mask: linux.STATX_BASIC_STATS, diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 89caee3df..8f7d5a9bb 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -53,7 +53,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { return nil, nil, nil, nil, err } vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -68,7 +68,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys root := mntns.Root() tearDown := func() { - root.DecRef() + root.DecRef(ctx) if err := f.Close(); err != nil { b.Fatalf("tearDown failed: %v", err) @@ -169,7 +169,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) // Create extfs submount. mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop) diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 55902322a..7a1b4219f 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -15,6 +15,7 @@ package ext import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -55,7 +56,7 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { // FIXME(b/134676337): filesystem.mu may not be locked as required by // inode.decRef(). d.inode.decRef() @@ -64,7 +65,7 @@ func (d *dentry) DecRef() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // TODO(b/134676337): Implement inotify. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. // @@ -76,4 +77,4 @@ func (d *dentry) Watches() *vfs.Watches { // OnZeroWatches implements vfs.Dentry.OnZeroWatches. // // TODO(b/134676337): Implement inotify. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(context.Context) {} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 357512c7e..0fc01668d 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -142,7 +142,7 @@ type directoryFD struct { var _ vfs.FileDescriptionImpl = (*directoryFD)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { if fd.iter == nil { return } diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index dac6effbf..08ffc2834 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -123,32 +123,32 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs.vfsfs.Init(vfsObj, &fsType, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } if fs.sb.Magic() != linux.EXT_SUPER_MAGIC { // mount(2) specifies that EINVAL should be returned if the superblock is // invalid. - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } // Refuse to mount if the filesystem is incompatible. if !isCompatible(fs.sb) { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } fs.bgs, err = readBlockGroups(dev, fs.sb) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode) if err != nil { - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } rootInode.incRef() diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 64e9a579f..2dbaee287 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -65,7 +65,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -80,7 +80,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys root := mntns.Root() tearDown := func() { - root.DecRef() + root.DecRef(ctx) if err := f.Close(); err != nil { t.Fatalf("tearDown failed: %v", err) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 557963e03..c714ddf73 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -84,7 +84,7 @@ var _ vfs.FilesystemImpl = (*filesystem)(nil) // - filesystem.mu must be locked (for writing if write param is true). // - !rp.Done(). // - inode == vfsd.Impl().(*Dentry).inode. -func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { +func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { if !inode.isDir() { return nil, nil, syserror.ENOTDIR } @@ -100,7 +100,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo } d := vfsd.Impl().(*dentry) if name == ".." { - isRoot, err := rp.CheckRoot(vfsd) + isRoot, err := rp.CheckRoot(ctx, vfsd) if err != nil { return nil, nil, err } @@ -108,7 +108,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo rp.Advance() return vfsd, inode, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, nil, err } rp.Advance() @@ -143,7 +143,7 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo child.name = name dir.childCache[name] = child } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, nil, err } if child.inode.isSymlink() && rp.ShouldFollowSymlink() { @@ -167,12 +167,12 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo // // Preconditions: // - filesystem.mu must be locked (for writing if write param is true). -func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { +func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode for !rp.Done() { var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) + vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write) if err != nil { return nil, nil, err } @@ -196,12 +196,12 @@ func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) // Preconditions: // - filesystem.mu must be locked (for writing if write param is true). // - !rp.Done(). -func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { +func walkParentLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode for !rp.Final() { var err error - vfsd, inode, err = stepLocked(rp, vfsd, inode, write) + vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write) if err != nil { return nil, nil, err } @@ -216,7 +216,7 @@ func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, e // the rp till the parent of the last component which should be an existing // directory. If parent is false then resolves rp entirely. Attemps to resolve // the path as far as it can with a read lock and upgrades the lock if needed. -func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) { +func (fs *filesystem) walk(ctx context.Context, rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) { var ( vfsd *vfs.Dentry inode *inode @@ -227,9 +227,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in // of disk. This reduces congestion (allows concurrent walks). fs.mu.RLock() if parent { - vfsd, inode, err = walkParentLocked(rp, false) + vfsd, inode, err = walkParentLocked(ctx, rp, false) } else { - vfsd, inode, err = walkLocked(rp, false) + vfsd, inode, err = walkLocked(ctx, rp, false) } fs.mu.RUnlock() @@ -238,9 +238,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in // walk is fine as this is a read only filesystem. fs.mu.Lock() if parent { - vfsd, inode, err = walkParentLocked(rp, true) + vfsd, inode, err = walkParentLocked(ctx, rp, true) } else { - vfsd, inode, err = walkLocked(rp, true) + vfsd, inode, err = walkLocked(ctx, rp, true) } fs.mu.Unlock() } @@ -283,7 +283,7 @@ func (fs *filesystem) statTo(stat *linux.Statfs) { // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -292,7 +292,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, false) + vfsd, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op // GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { - vfsd, inode, err := fs.walk(rp, true) + vfsd, inode, err := fs.walk(ctx, rp, true) if err != nil { return nil, err } @@ -322,7 +322,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - vfsd, inode, err := fs.walk(rp, false) + vfsd, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -336,7 +336,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return "", err } @@ -349,7 +349,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // StatAt implements vfs.FilesystemImpl.StatAt. func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return linux.Statx{}, err } @@ -360,7 +360,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // StatFSAt implements vfs.FilesystemImpl.StatFSAt. func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { - if _, _, err := fs.walk(rp, false); err != nil { + if _, _, err := fs.walk(ctx, rp, false); err != nil { return linux.Statfs{}, err } @@ -370,7 +370,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } @@ -390,7 +390,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EEXIST } - if _, _, err := fs.walk(rp, true); err != nil { + if _, _, err := fs.walk(ctx, rp, true); err != nil { return err } @@ -403,7 +403,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return syserror.EEXIST } - if _, _, err := fs.walk(rp, true); err != nil { + if _, _, err := fs.walk(ctx, rp, true); err != nil { return err } @@ -416,7 +416,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return syserror.EEXIST } - _, _, err := fs.walk(rp, true) + _, _, err := fs.walk(ctx, rp, true) if err != nil { return err } @@ -430,7 +430,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return syserror.ENOENT } - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -440,7 +440,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // RmdirAt implements vfs.FilesystemImpl.RmdirAt. func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -454,7 +454,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -468,7 +468,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ return syserror.EEXIST } - _, _, err := fs.walk(rp, true) + _, _, err := fs.walk(ctx, rp, true) if err != nil { return err } @@ -478,7 +478,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -492,7 +492,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { - _, inode, err := fs.walk(rp, false) + _, inode, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -506,7 +506,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return nil, err } @@ -515,7 +515,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return "", err } @@ -524,7 +524,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } @@ -533,7 +533,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { - _, _, err := fs.walk(rp, false) + _, _, err := fs.walk(ctx, rp, false) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 66d14bb95..e73e740d6 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -79,7 +79,7 @@ type regularFileFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() {} +func (fd *regularFileFD) Release(context.Context) {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index 62efd4095..2fd0d1fa8 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -73,7 +73,7 @@ type symlinkFD struct { var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil) // Release implements vfs.FileDescriptionImpl.Release. -func (fd *symlinkFD) Release() {} +func (fd *symlinkFD) Release(context.Context) {} // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 2225076bc..e522ff9a0 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -99,7 +99,7 @@ type DeviceFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DeviceFD) Release() { +func (fd *DeviceFD) Release(context.Context) { fd.fs.conn.connected = false } diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 84c222ad6..1ffe7ccd2 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -356,12 +356,12 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque vfsObj := &vfs.VirtualFilesystem{} fuseDev := &DeviceFD{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(system.Ctx); err != nil { return nil, nil, err } vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef() + defer vd.DecRef(system.Ctx) if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, nil, err } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 200a93bbf..a1405f7c3 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -191,9 +191,9 @@ func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // inode implements kernfs.Inode. diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 8c7c8e1b3..1679066ba 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -122,7 +122,7 @@ type directoryFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(context.Context) { } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. @@ -139,7 +139,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.dirents = ds } - d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) if d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 00e3c99cd..e6af37d0d 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -55,7 +55,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // Sync regular files. for _, d := range ds { err := d.syncSharedHandle(ctx) - d.DecRef() + d.DecRef(ctx) if err != nil && retErr == nil { retErr = err } @@ -65,7 +65,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // handles (so they won't be synced by the above). for _, sffd := range sffds { err := sffd.Sync(ctx) - sffd.vfsfd.DecRef() + sffd.vfsfd.DecRef(ctx) if err != nil && retErr == nil { retErr = err } @@ -133,7 +133,7 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() @@ -146,7 +146,7 @@ afterSymlink: // // Call rp.CheckMount() before updating d.parent's metadata, since if // we traverse to another mount then d.parent's metadata is irrelevant. - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { @@ -164,7 +164,7 @@ afterSymlink: if child == nil { return nil, syserror.ENOENT } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { @@ -239,7 +239,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // has 0 references, drop it). Wait to update parent.children until we // know what to replace the existing dentry with (i.e. one of the // returns below), to avoid a redundant map access. - vfsObj.InvalidateDentry(&child.vfsd) + vfsObj.InvalidateDentry(ctx, &child.vfsd) if child.isSynthetic() { // Normally we don't mark invalidated dentries as deleted since // they may still exist (but at a different path), and also for @@ -332,7 +332,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -384,7 +384,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if fs.opts.interop == InteropModeShared { @@ -405,7 +405,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if child := parent.children[name]; child != nil { @@ -426,7 +426,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if dir { ev |= linux.IN_ISDIR } - parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } @@ -434,7 +434,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -470,7 +470,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -600,17 +600,17 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // Generate inotify events for rmdir or unlink. if dir { - parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) + parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) } else { var cw *vfs.Watches if child != nil { cw = &child.watches } - vfs.InotifyRemoveChild(cw, &parent.watches, name) + vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name) } if child != nil { - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) child.setDeleted() if child.isSynthetic() { parent.syntheticChildren-- @@ -637,7 +637,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { return @@ -645,20 +645,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) { if len(**ds) != 0 { fs.renameMu.Lock() for _, d := range **ds { - d.checkCachingLocked() + d.checkCachingLocked(ctx) } fs.renameMu.Unlock() } putDentrySlice(*ds) } -func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { +func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { if *ds == nil { fs.renameMu.Unlock() return } for _, d := range **ds { - d.checkCachingLocked() + d.checkCachingLocked(ctx) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -668,7 +668,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -680,7 +680,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -701,7 +701,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { // Get updated metadata for start as required by @@ -812,7 +812,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { @@ -1126,7 +1126,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } - d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) + d.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) return childVFSFD, nil } @@ -1134,7 +1134,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -1154,7 +1154,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa var ds *[]*dentry fs.renameMu.Lock() - defer fs.renameMuUnlockAndCheckCaching(&ds) + defer fs.renameMuUnlockAndCheckCaching(ctx, &ds) newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds) if err != nil { return err @@ -1244,7 +1244,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return nil } mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { return err } @@ -1269,7 +1269,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } // Update the dentry tree. - vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD) + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) if replaced != nil { replaced.setDeleted() if replaced.isSynthetic() { @@ -1331,17 +1331,17 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -1350,7 +1350,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statx{}, err @@ -1367,7 +1367,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statfs{}, err @@ -1417,7 +1417,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -1443,7 +1443,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -1455,7 +1455,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) + defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -1469,16 +1469,16 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } @@ -1488,16 +1488,16 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } if err := d.removexattr(ctx, rp.Credentials(), name); err != nil { - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - fs.renameMuRUnlockAndCheckCaching(&ds) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index e20de84b5..2e5575d8d 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -482,7 +482,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { attachFile.close(ctx) - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the @@ -495,8 +495,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { - ctx := context.Background() +func (fs *filesystem) Release(ctx context.Context) { mf := fs.mfp.MemoryFile() fs.syncMu.Lock() @@ -1089,10 +1088,10 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { d.fs.renameMu.Lock() - d.checkCachingLocked() + d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() } else if refs < 0 { panic("gofer.dentry.DecRef() called without holding a reference") @@ -1109,7 +1108,7 @@ func (d *dentry) decRefLocked() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.isDir() { events |= linux.IN_ISDIR } @@ -1117,9 +1116,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { d.fs.renameMu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted()) + d.parent.watches.Notify(ctx, d.name, events, cookie, et, d.isDeleted()) } - d.watches.Notify("", events, cookie, et, d.isDeleted()) + d.watches.Notify(ctx, "", events, cookie, et, d.isDeleted()) d.fs.renameMu.RUnlock() } @@ -1131,10 +1130,10 @@ func (d *dentry) Watches() *vfs.Watches { // OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. // // If no watches are left on this dentry and it has no references, cache it. -func (d *dentry) OnZeroWatches() { +func (d *dentry) OnZeroWatches(ctx context.Context) { if atomic.LoadInt64(&d.refs) == 0 { d.fs.renameMu.Lock() - d.checkCachingLocked() + d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() } } @@ -1149,7 +1148,7 @@ func (d *dentry) OnZeroWatches() { // do nothing. // // Preconditions: d.fs.renameMu must be locked for writing. -func (d *dentry) checkCachingLocked() { +func (d *dentry) checkCachingLocked(ctx context.Context) { // Dentries with a non-zero reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path // resolution, which requires renameMu, so if d.refs is zero then it will @@ -1171,14 +1170,14 @@ func (d *dentry) checkCachingLocked() { // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { if d.isDeleted() { - d.watches.HandleDeletion() + d.watches.HandleDeletion(ctx) } if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- d.cached = false } - d.destroyLocked() + d.destroyLocked(ctx) return } // If d still has inotify watches and it is not deleted or invalidated, we @@ -1213,7 +1212,7 @@ func (d *dentry) checkCachingLocked() { if !victim.vfsd.IsDead() { // Note that victim can't be a mount point (in any mount // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(&victim.vfsd) + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) delete(victim.parent.children, victim.name) // We're only deleting the dentry, not the file it // represents, so we don't need to update @@ -1221,7 +1220,7 @@ func (d *dentry) checkCachingLocked() { } victim.parent.dirMu.Unlock() } - victim.destroyLocked() + victim.destroyLocked(ctx) } // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. @@ -1233,7 +1232,7 @@ func (d *dentry) checkCachingLocked() { // // Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is // not a child dentry. -func (d *dentry) destroyLocked() { +func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: // Mark the dentry destroyed. @@ -1244,7 +1243,6 @@ func (d *dentry) destroyLocked() { panic("dentry.destroyLocked() called with references on the dentry") } - ctx := context.Background() d.handleMu.Lock() if !d.handle.file.isNil() { mf := d.fs.mfp.MemoryFile() @@ -1276,7 +1274,7 @@ func (d *dentry) destroyLocked() { // d.fs.renameMu. if d.parent != nil { if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkCachingLocked() + d.parent.checkCachingLocked(ctx) } else if refs < 0 { panic("gofer.dentry.DecRef() called without holding a reference") } @@ -1514,7 +1512,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return err } if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent) + fd.dentry().InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -1535,7 +1533,7 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { return err } - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } @@ -1545,7 +1543,7 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { return err } - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index adff39490..56d80bcf8 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -50,7 +50,7 @@ func TestDestroyIdempotent(t *testing.T) { } parent.cacheNewChildLocked(child, "child") - child.checkCachingLocked() + child.checkCachingLocked(ctx) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) } @@ -58,6 +58,6 @@ func TestDestroyIdempotent(t *testing.T) { if got := atomic.LoadInt64(&parent.refs); got != -1 { t.Fatalf("parent.refs=%d, want: -1", got) } - child.checkCachingLocked() - child.checkCachingLocked() + child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 09f142cfc..420e8efe2 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -48,7 +48,7 @@ type regularFileFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { +func (fd *regularFileFD) Release(context.Context) { } // OnClose implements vfs.FileDescriptionImpl.OnClose. diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index d6dbe9092..85d2bee72 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -108,7 +108,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect // We don't need the receiver. c.CloseRecv() - c.Release() + c.Release(ctx) return c, nil } @@ -136,8 +136,8 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla } // Release implements transport.BoundEndpoint.Release. -func (e *endpoint) Release() { - e.dentry.DecRef() +func (e *endpoint) Release(ctx context.Context) { + e.dentry.DecRef(ctx) } // Passcred implements transport.BoundEndpoint.Passcred. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 811528982..fc269ef2b 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -80,11 +80,11 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *specialFileFD) Release() { +func (fd *specialFileFD) Release(ctx context.Context) { if fd.haveQueue { fdnotifier.RemoveFD(fd.handle.fd) } - fd.handle.close(context.Background()) + fd.handle.close(ctx) fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) fs.syncMu.Lock() delete(fs.specialFileFDs, fd) diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index b9082a20f..0135e4428 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -58,7 +58,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (c *scmRights) Release() { +func (c *scmRights) Release(ctx context.Context) { for _, fd := range c.fds { syscall.Close(fd) } diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index c894f2ca0..bf922c566 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -117,7 +117,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) d.Init(i) // i.open will take a reference on d. - defer d.DecRef() + defer d.DecRef(ctx) // For simplicity, fileDescription.offset is set to 0. Technically, we // should only set to 0 on files that are not seekable (sockets, pipes, @@ -168,9 +168,9 @@ type filesystem struct { devMinor uint32 } -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { @@ -431,12 +431,12 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // DecRef implements kernfs.Inode. -func (i *inode) DecRef() { - i.AtomicRefCount.DecRefWithDestructor(i.Destroy) +func (i *inode) DecRef(ctx context.Context) { + i.AtomicRefCount.DecRefWithDestructor(ctx, i.Destroy) } // Destroy implements kernfs.Inode. -func (i *inode) Destroy() { +func (i *inode) Destroy(context.Context) { if i.wouldBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } @@ -542,7 +542,7 @@ func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux } // Release implements vfs.FileDescriptionImpl. -func (f *fileDescription) Release() { +func (f *fileDescription) Release(context.Context) { // noop } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index fd16bd92d..4979dd0a9 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -139,7 +139,7 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable } // Send implements transport.ConnectedEndpoint.Send. -func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -216,7 +216,7 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() @@ -317,8 +317,8 @@ func (c *ConnectedEndpoint) destroyLocked() { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. -func (c *ConnectedEndpoint) Release() { - c.ref.DecRefWithDestructor(func() { +func (c *ConnectedEndpoint) Release(ctx context.Context) { + c.ref.DecRefWithDestructor(ctx, func(context.Context) { c.mu.Lock() c.destroyLocked() c.mu.Unlock() @@ -347,8 +347,8 @@ func (e *SCMConnectedEndpoint) Init() error { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. -func (e *SCMConnectedEndpoint) Release() { - e.ref.DecRefWithDestructor(func() { +func (e *SCMConnectedEndpoint) Release(ctx context.Context) { + e.ref.DecRefWithDestructor(ctx, func(context.Context) { e.mu.Lock() if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 4ee9270cc..d372c60cb 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -67,12 +67,12 @@ func (t *TTYFileDescription) ForegroundProcessGroup() *kernel.ProcessGroup { } // Release implements fs.FileOperations.Release. -func (t *TTYFileDescription) Release() { +func (t *TTYFileDescription) Release(ctx context.Context) { t.mu.Lock() t.fgProcessGroup = nil t.mu.Unlock() - t.fileDescription.Release() + t.fileDescription.Release(ctx) } // PRead implements vfs.FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index c6c4472e7..12adf727a 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -122,7 +122,7 @@ func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, of } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DynamicBytesFD) Release() {} +func (fd *DynamicBytesFD) Release(context.Context) {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 1d37ccb98..fcee6200a 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -113,7 +113,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *GenericDirectoryFD) Release() {} +func (fd *GenericDirectoryFD) Release(context.Context) {} func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 61a36cff9..d7edb6342 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -56,13 +56,13 @@ afterSymlink: return vfsd, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() return vfsd, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } rp.Advance() @@ -77,7 +77,7 @@ afterSymlink: if err != nil { return nil, err } - if err := rp.CheckMount(&next.vfsd); err != nil { + if err := rp.CheckMount(ctx, &next.vfsd); err != nil { return nil, err } // Resolve any symlink at current path component. @@ -88,7 +88,7 @@ afterSymlink: } if targetVD.Ok() { err := rp.HandleJump(targetVD) - targetVD.DecRef() + targetVD.DecRef(ctx) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // Cached dentry exists, revalidate. if !child.inode.Valid(ctx) { delete(parent.children, name) - vfsObj.InvalidateDentry(&child.vfsd) + vfsObj.InvalidateDentry(ctx, &child.vfsd) fs.deferDecRef(&child.vfsd) // Reference from Lookup. child = nil } @@ -234,7 +234,7 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Den } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release() { +func (fs *Filesystem) Release(context.Context) { } // Sync implements vfs.FilesystemImpl.Sync. @@ -246,7 +246,7 @@ func (fs *Filesystem) Sync(ctx context.Context) error { // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() _, inode, err := fs.walkExistingLocked(ctx, rp) @@ -259,7 +259,7 @@ func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() vfsd, inode, err := fs.walkExistingLocked(ctx, rp) if err != nil { @@ -282,7 +282,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op // GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt. func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() vfsd, _, err := fs.walkParentDirLocked(ctx, rp) if err != nil { @@ -300,7 +300,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. fs.mu.Lock() defer fs.mu.Unlock() parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -337,7 +337,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v fs.mu.Lock() defer fs.mu.Unlock() parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -365,7 +365,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v fs.mu.Lock() defer fs.mu.Unlock() parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -397,7 +397,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Do not create new file. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - defer fs.processDeferredDecRefs() + defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() vfsd, inode, err := fs.walkExistingLocked(ctx, rp) if err != nil { @@ -429,7 +429,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf } afterTrailingSymlink: parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return nil, err } @@ -483,7 +483,7 @@ afterTrailingSymlink: } if targetVD.Ok() { err := rp.HandleJump(targetVD) - targetVD.DecRef() + targetVD.DecRef(ctx) if err != nil { return nil, err } @@ -507,7 +507,7 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st fs.mu.RLock() d, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return "", err } @@ -526,7 +526,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 fs.mu.Lock() - defer fs.processDeferredDecRefsLocked() + defer fs.processDeferredDecRefsLocked(ctx) defer fs.mu.Unlock() // Resolve the destination directory first to verify that it's on this @@ -584,7 +584,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) virtfs := rp.VirtualFilesystem() // We can't deadlock here due to lock ordering because we're protected from @@ -615,7 +615,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa dstDir.children = make(map[string]*Dentry) } dstDir.children[pc] = src - virtfs.CommitRenameReplaceDentry(srcVFSD, replaced) + virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaced) return nil } @@ -624,7 +624,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error fs.mu.Lock() defer fs.mu.Unlock() vfsd, inode, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -648,7 +648,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error defer parentDentry.dirMu.Unlock() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } @@ -656,7 +656,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error virtfs.AbortDeleteDentry(vfsd) return err } - virtfs.CommitDeleteDentry(vfsd) + virtfs.CommitDeleteDentry(ctx, vfsd) return nil } @@ -665,7 +665,7 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.mu.RLock() _, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } @@ -680,7 +680,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf fs.mu.RLock() _, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statx{}, err } @@ -692,7 +692,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statfs{}, err } @@ -708,7 +708,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ fs.mu.Lock() defer fs.mu.Unlock() parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -733,7 +733,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error fs.mu.Lock() defer fs.mu.Unlock() vfsd, _, err := fs.walkExistingLocked(ctx, rp) - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } @@ -753,7 +753,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error parentDentry.dirMu.Lock() defer parentDentry.dirMu.Unlock() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } @@ -761,7 +761,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error virtfs.AbortDeleteDentry(vfsd) return err } - virtfs.CommitDeleteDentry(vfsd) + virtfs.CommitDeleteDentry(ctx, vfsd) return nil } @@ -770,7 +770,7 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath fs.mu.RLock() _, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return nil, err } @@ -785,7 +785,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return nil, err } @@ -798,7 +798,7 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return "", err } @@ -811,7 +811,7 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } @@ -824,7 +824,7 @@ func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() - fs.processDeferredDecRefs() + fs.processDeferredDecRefs(ctx) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 579e627f0..c3efcf3ec 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -40,7 +40,7 @@ func (InodeNoopRefCount) IncRef() { } // DecRef implements Inode.DecRef. -func (InodeNoopRefCount) DecRef() { +func (InodeNoopRefCount) DecRef(context.Context) { } // TryIncRef implements Inode.TryIncRef. @@ -49,7 +49,7 @@ func (InodeNoopRefCount) TryIncRef() bool { } // Destroy implements Inode.Destroy. -func (InodeNoopRefCount) Destroy() { +func (InodeNoopRefCount) Destroy(context.Context) { } // InodeDirectoryNoNewChildren partially implements the Inode interface. @@ -366,12 +366,12 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { } // DecRef implements Inode.DecRef. -func (o *OrderedChildren) DecRef() { - o.AtomicRefCount.DecRefWithDestructor(o.Destroy) +func (o *OrderedChildren) DecRef(ctx context.Context) { + o.AtomicRefCount.DecRefWithDestructor(ctx, o.Destroy) } // Destroy cleans up resources referenced by this OrderedChildren. -func (o *OrderedChildren) Destroy() { +func (o *OrderedChildren) Destroy(context.Context) { o.mu.Lock() defer o.mu.Unlock() o.order.Reset() diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 46f207664..080118841 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -116,17 +116,17 @@ func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { // processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the // droppedDentries list. See comment on Filesystem.mu. -func (fs *Filesystem) processDeferredDecRefs() { +func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) { fs.mu.Lock() - fs.processDeferredDecRefsLocked() + fs.processDeferredDecRefsLocked(ctx) fs.mu.Unlock() } // Precondition: fs.mu must be held for writing. -func (fs *Filesystem) processDeferredDecRefsLocked() { +func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) { fs.droppedDentriesMu.Lock() for _, d := range fs.droppedDentries { - d.DecRef() + d.DecRef(ctx) } fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse. fs.droppedDentriesMu.Unlock() @@ -212,16 +212,16 @@ func (d *Dentry) isSymlink() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef() { - d.AtomicRefCount.DecRefWithDestructor(d.destroy) +func (d *Dentry) DecRef(ctx context.Context) { + d.AtomicRefCount.DecRefWithDestructor(ctx, d.destroy) } // Precondition: Dentry must be removed from VFS' dentry cache. -func (d *Dentry) destroy() { - d.inode.DecRef() // IncRef from Init. +func (d *Dentry) destroy(ctx context.Context) { + d.inode.DecRef(ctx) // IncRef from Init. d.inode = nil if d.parent != nil { - d.parent.DecRef() // IncRef from Dentry.InsertChild. + d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild. } } @@ -230,7 +230,7 @@ func (d *Dentry) destroy() { // Although Linux technically supports inotify on pseudo filesystems (inotify // is implemented at the vfs layer), it is not particularly useful. It is left // unimplemented until someone actually needs it. -func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} +func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. func (d *Dentry) Watches() *vfs.Watches { @@ -238,7 +238,7 @@ func (d *Dentry) Watches() *vfs.Watches { } // OnZeroWatches implements vfs.Dentry.OnZeroWatches. -func (d *Dentry) OnZeroWatches() {} +func (d *Dentry) OnZeroWatches(context.Context) {} // InsertChild inserts child into the vfs dentry cache with the given name under // this dentry. This does not update the directory inode, so calling this on @@ -326,12 +326,12 @@ type Inode interface { type inodeRefs interface { IncRef() - DecRef() + DecRef(ctx context.Context) TryIncRef() bool // Destroy is called when the inode reaches zero references. Destroy release // all resources (references) on objects referenced by the inode, including // any child dentries. - Destroy() + Destroy(ctx context.Context) } type inodeMetadata interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index dc407eb1d..c5d5afedf 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -46,7 +46,7 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { ctx := contexttest.Context(t) creds := auth.CredentialsFromContext(ctx) v := &vfs.VirtualFilesystem{} - if err := v.Init(); err != nil { + if err := v.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ @@ -163,7 +163,7 @@ func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (* dir := d.fs.newDir(creds, opts.Mode, nil) dirVFSD := dir.VFSDentry() if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { - dir.DecRef() + dir.DecRef(ctx) return nil, err } d.IncLinks(1) @@ -175,7 +175,7 @@ func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (* f := d.fs.newFile(creds, "") fVFSD := f.VFSDentry() if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { - f.DecRef() + f.DecRef(ctx) return nil, err } return fVFSD, nil @@ -213,7 +213,7 @@ func TestBasic(t *testing.T) { }) }) defer sys.Destroy() - sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() + sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef(sys.Ctx) } func TestMkdirGetDentry(t *testing.T) { @@ -228,7 +228,7 @@ func TestMkdirGetDentry(t *testing.T) { if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) } - sys.GetDentryOrDie(pop).DecRef() + sys.GetDentryOrDie(pop).DecRef(sys.Ctx) } func TestReadStaticFile(t *testing.T) { @@ -246,7 +246,7 @@ func TestReadStaticFile(t *testing.T) { if err != nil { t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(sys.Ctx) content, err := sys.ReadToEnd(fd) if err != nil { @@ -273,7 +273,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } // Close the file. The file should persist. - fd.DecRef() + fd.DecRef(sys.Ctx) fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{ Flags: linux.O_RDONLY, @@ -281,7 +281,7 @@ func TestCreateNewFileInStaticDir(t *testing.T) { if err != nil { t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) } - fd.DecRef() + fd.DecRef(sys.Ctx) } func TestDirFDReadWrite(t *testing.T) { @@ -297,7 +297,7 @@ func TestDirFDReadWrite(t *testing.T) { if err != nil { t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(sys.Ctx) // Read/Write should fail for directory FDs. if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 8f8dcfafe..b3d19ff82 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -98,7 +98,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { if err != nil { return err } - defer oldFD.DecRef() + defer oldFD.DecRef(ctx) newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{ Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL, Mode: linux.FileMode(d.mode &^ linux.S_IFMT), @@ -106,7 +106,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { if err != nil { return err } - defer newFD.DecRef() + defer newFD.DecRef(ctx) bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size for { readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) @@ -241,13 +241,13 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Mask: linux.STATX_INO, }) if err != nil { - d.upperVD.DecRef() + d.upperVD.DecRef(ctx) d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return err } if upperStat.Mask&linux.STATX_INO == 0 { - d.upperVD.DecRef() + d.upperVD.DecRef(ctx) d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return syserror.EREMOTE diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index f5c2462a5..fccb94105 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -46,7 +46,7 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string readdirErr = err return false } - defer layerFD.DecRef() + defer layerFD.DecRef(ctx) // Reuse slice allocated for maybeWhiteouts from a previous layer to // reduce allocations. @@ -108,7 +108,7 @@ type directoryFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. @@ -177,7 +177,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { readdirErr = err return false } - defer layerFD.DecRef() + defer layerFD.DecRef(ctx) // Reuse slice allocated for maybeWhiteouts from a previous layer to // reduce allocations. @@ -282,6 +282,6 @@ func (fd *directoryFD) Sync(ctx context.Context) error { return err } err = upperFD.Sync(ctx) - upperFD.DecRef() + upperFD.DecRef(ctx) return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 6b705e955..986b36ead 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -77,7 +77,7 @@ func putDentrySlice(ds *[]*dentry) { // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { fs.renameMu.RUnlock() if *ds == nil { return @@ -85,20 +85,20 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ds **[]*dentry) { if len(**ds) != 0 { fs.renameMu.Lock() for _, d := range **ds { - d.checkDropLocked() + d.checkDropLocked(ctx) } fs.renameMu.Unlock() } putDentrySlice(*ds) } -func (fs *filesystem) renameMuUnlockAndCheckDrop(ds **[]*dentry) { +func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { if *ds == nil { fs.renameMu.Unlock() return } for _, d := range **ds { - d.checkDropLocked() + d.checkDropLocked(ctx) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -126,13 +126,13 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() return d, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } rp.Advance() @@ -142,7 +142,7 @@ afterSymlink: if err != nil { return nil, err } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { @@ -272,11 +272,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str }) if lookupErr != nil { - child.destroyLocked() + child.destroyLocked(ctx) return nil, lookupErr } if !existsOnAnyLayer { - child.destroyLocked() + child.destroyLocked(ctx) return nil, syserror.ENOENT } @@ -430,7 +430,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -501,7 +501,7 @@ func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.V func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -513,7 +513,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -532,7 +532,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -553,7 +553,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -720,7 +720,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) if rp.Done() { @@ -825,7 +825,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf fd.LockFD.Init(&d.locks) layerFDOpts := layerFD.Options() if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil { - layerFD.DecRef() + layerFD.DecRef(ctx) return nil, err } return &fd.vfsfd, nil @@ -920,7 +920,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving fd.LockFD.Init(&child.locks) upperFDOpts := upperFD.Options() if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil { - upperFD.DecRef() + upperFD.DecRef(ctx) // Don't bother with cleanup; the file was created successfully, we // just can't open it anymore for some reason. return nil, err @@ -932,7 +932,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -952,7 +952,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa var ds *[]*dentry fs.renameMu.Lock() - defer fs.renameMuUnlockAndCheckDrop(&ds) + defer fs.renameMuUnlockAndCheckDrop(ctx, &ds) newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds) if err != nil { return err @@ -979,7 +979,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -1001,7 +1001,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -1086,7 +1086,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) parent.dirents = nil @@ -1097,7 +1097,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -1132,7 +1132,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statx{}, err @@ -1160,7 +1160,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return linux.Statfs{}, err @@ -1211,7 +1211,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) start := rp.Start().Impl().(*dentry) parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { @@ -1233,7 +1233,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -1298,7 +1298,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } if child != nil { - vfsObj.CommitDeleteDentry(&child.vfsd) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) } @@ -1310,7 +1310,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err @@ -1324,7 +1324,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err @@ -1336,7 +1336,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err @@ -1348,7 +1348,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(&ds) + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) _, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index c0749e711..d3060a481 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -81,11 +81,11 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR) if oldOffErr == nil { if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil { - upperFD.DecRef() + upperFD.DecRef(ctx) return nil, err } } - fd.cachedFD.DecRef() + fd.cachedFD.DecRef(ctx) fd.copiedUp = true fd.cachedFD = upperFD fd.cachedFlags = statusFlags @@ -99,8 +99,8 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nonDirectoryFD) Release() { - fd.cachedFD.DecRef() +func (fd *nonDirectoryFD) Release(ctx context.Context) { + fd.cachedFD.DecRef(ctx) fd.cachedFD = nil } @@ -138,7 +138,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux Mask: layerMask, Sync: opts.Sync, }) - wrappedFD.DecRef() + wrappedFD.DecRef(ctx) if err != nil { return linux.Statx{}, err } @@ -187,7 +187,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off if err != nil { return 0, err } - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) return wrappedFD.PRead(ctx, dst, offset, opts) } @@ -209,7 +209,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of if err != nil { return 0, err } - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) return wrappedFD.PWrite(ctx, src, offset, opts) } @@ -250,7 +250,7 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { return err } wrappedFD.IncRef() - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) fd.mu.Unlock() return wrappedFD.Sync(ctx) } @@ -261,6 +261,6 @@ func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOp if err != nil { return err } - defer wrappedFD.DecRef() + defer wrappedFD.DecRef(ctx) return wrappedFD.ConfigureMMap(ctx, opts) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e720d4825..75cc006bf 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -123,7 +123,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // filesystem with any number of lower layers. } else { vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef() + defer vfsroot.DecRef(ctx) upperPathname, ok := mopts["upperdir"] if ok { delete(mopts, "upperdir") @@ -147,13 +147,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) return nil, nil, err } - defer upperRoot.DecRef() + defer upperRoot.DecRef(ctx) privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) if err != nil { ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) return nil, nil, err } - defer privateUpperRoot.DecRef() + defer privateUpperRoot.DecRef(ctx) fsopts.UpperRoot = privateUpperRoot } lowerPathnamesStr, ok := mopts["lowerdir"] @@ -190,13 +190,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef() + defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) if err != nil { ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer privateLowerRoot.DecRef() + defer privateLowerRoot.DecRef(ctx) fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } @@ -264,19 +264,19 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Mask: rootStatMask, }) if err != nil { - root.destroyLocked() - fs.vfsfs.DecRef() + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, err } if rootStat.Mask&rootStatMask != rootStatMask { - root.destroyLocked() - fs.vfsfs.DecRef() + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EREMOTE } if isWhiteout(&rootStat) { ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") - root.destroyLocked() - fs.vfsfs.DecRef() + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL } root.mode = uint32(rootStat.Mode) @@ -319,17 +319,17 @@ func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forc } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { vfsObj := fs.vfsfs.VirtualFilesystem() vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor) for _, lowerDevMinor := range fs.lowerDevMinors { vfsObj.PutAnonBlockDevMinor(lowerDevMinor) } if fs.opts.UpperRoot.Ok() { - fs.opts.UpperRoot.DecRef() + fs.opts.UpperRoot.DecRef(ctx) } for _, lowerRoot := range fs.opts.LowerRoots { - lowerRoot.DecRef() + lowerRoot.DecRef(ctx) } } @@ -452,10 +452,10 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { +func (d *dentry) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { d.fs.renameMu.Lock() - d.checkDropLocked() + d.checkDropLocked(ctx) d.fs.renameMu.Unlock() } else if refs < 0 { panic("overlay.dentry.DecRef() called without holding a reference") @@ -466,7 +466,7 @@ func (d *dentry) DecRef() { // becomes deleted. // // Preconditions: d.fs.renameMu must be locked for writing. -func (d *dentry) checkDropLocked() { +func (d *dentry) checkDropLocked(ctx context.Context) { // Dentries with a positive reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path // resolution, which requires renameMu, so if d.refs is zero then it will @@ -476,14 +476,14 @@ func (d *dentry) checkDropLocked() { return } // Refs is still zero; destroy it. - d.destroyLocked() + d.destroyLocked(ctx) return } // destroyLocked destroys the dentry. // // Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. -func (d *dentry) destroyLocked() { +func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: // Mark the dentry destroyed. @@ -495,10 +495,10 @@ func (d *dentry) destroyLocked() { } if d.upperVD.Ok() { - d.upperVD.DecRef() + d.upperVD.DecRef(ctx) } for _, lowerVD := range d.lowerVDs { - lowerVD.DecRef() + lowerVD.DecRef(ctx) } if d.parent != nil { @@ -510,7 +510,7 @@ func (d *dentry) destroyLocked() { // Drop the reference held by d on its parent without recursively // locking d.fs.renameMu. if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked() + d.parent.checkDropLocked(ctx) } else if refs < 0 { panic("overlay.dentry.DecRef() called without holding a reference") } @@ -518,7 +518,7 @@ func (d *dentry) destroyLocked() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) { // TODO(gvisor.dev/issue/1479): Implement inotify. } @@ -531,7 +531,7 @@ func (d *dentry) Watches() *vfs.Watches { // OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. // // TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(context.Context) {} // iterLayers invokes yield on each layer comprising d, from top to bottom. If // any call to yield returns false, iterLayer stops iteration. diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 811f80a5f..2ca793db9 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -63,9 +63,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // PrependPath implements vfs.FilesystemImpl.PrependPath. @@ -160,6 +160,6 @@ func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vf inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(inode) - defer d.DecRef() + defer d.DecRef(ctx) return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) } diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 609210253..2463d51cd 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -77,9 +77,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // dynamicInode is an overfitted interface for common Inodes with diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index fea29e5f0..f0d3f7f5e 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -43,12 +43,12 @@ func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) return file, flags } -func taskFDExists(t *kernel.Task, fd int32) bool { +func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool { file, _ := getTaskFD(t, fd) if file == nil { return false } - file.DecRef() + file.DecRef(ctx) return true } @@ -68,7 +68,7 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.GetFDs() + fds = fdTable.GetFDs(ctx) } }) @@ -135,7 +135,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return nil, syserror.ENOENT } fd := int32(fdInt) - if !taskFDExists(i.task, fd) { + if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()) @@ -204,9 +204,9 @@ func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { if file == nil { return "", syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) root := vfs.RootFromContext(ctx) - defer root.DecRef() + defer root.DecRef(ctx) return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } @@ -215,7 +215,7 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen if file == nil { return vfs.VirtualDentry{}, "", syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) vd := file.VirtualDentry() vd.IncRef() return vd, "", nil @@ -258,7 +258,7 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, return nil, syserror.ENOENT } fd := int32(fdInt) - if !taskFDExists(i.task, fd) { + if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } data := &fdInfoData{ @@ -297,7 +297,7 @@ func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { if file == nil { return syserror.ENOENT } - defer file.DecRef() + defer file.DecRef(ctx) // TODO(b/121266871): Include pos, locks, and other data. For now we only // have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 859b7d727..830b78949 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -677,7 +677,7 @@ func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { if err != nil { return "", err } - defer exec.DecRef() + defer exec.DecRef(ctx) return exec.PathnameWithDeleted(ctx), nil } @@ -692,7 +692,7 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent if err != nil { return vfs.VirtualDentry{}, "", err } - defer exec.DecRef() + defer exec.DecRef(ctx) vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() vd.IncRef() @@ -748,7 +748,7 @@ func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Root has been destroyed. Don't try to read mounts. return nil } - defer rootDir.DecRef() + defer rootDir.DecRef(ctx) i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf) return nil } @@ -779,7 +779,7 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Root has been destroyed. Don't try to read mounts. return nil } - defer rootDir.DecRef() + defer rootDir.DecRef(ctx) i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf) return nil } @@ -825,7 +825,7 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir dentry.Init(&namespaceInode{}) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) vd.IncRef() - dentry.DecRef() + dentry.DecRef(ctx) return vd, "", nil } @@ -887,8 +887,8 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err } // Release implements FileDescriptionImpl. -func (fd *namespaceFD) Release() { - fd.inode.DecRef() +func (fd *namespaceFD) Release(ctx context.Context) { + fd.inode.DecRef(ctx) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 6bde27376..a4c884bf9 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -212,7 +212,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { continue } if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX { - s.DecRef() + s.DecRef(ctx) // Not a unix socket. continue } @@ -281,7 +281,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { } fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil } @@ -359,7 +359,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { - s.DecRef() + s.DecRef(ctx) // Not tcp4 sockets. continue } @@ -455,7 +455,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil @@ -524,7 +524,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { - s.DecRef() + s.DecRef(ctx) // Not udp4 socket. continue } @@ -600,7 +600,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "\n") - s.DecRef() + s.DecRef(ctx) } return nil } diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 19abb5034..3c9297dee 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -218,7 +218,7 @@ func TestTasks(t *testing.T) { if err != nil { t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) buf := make([]byte, 1) bufIOSeq := usermem.BytesIOSequence(buf) if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { @@ -336,7 +336,7 @@ func TestTasksOffset(t *testing.T) { if err != nil { t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil { t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) } @@ -441,7 +441,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err) continue } - defer child.DecRef() + defer child.DecRef(ctx) stat, err := child.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("Stat(%v) failed: %v", absPath, err) @@ -476,7 +476,7 @@ func TestTree(t *testing.T) { if err != nil { t.Fatalf("failed to create test file: %v", err) } - defer file.DecRef() + defer file.DecRef(s.Ctx) var tasks []*kernel.Task for i := 0; i < 5; i++ { @@ -501,5 +501,5 @@ func TestTree(t *testing.T) { t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err) } iterateDir(ctx, t, s, fd) - fd.DecRef() + fd.DecRef(ctx) } diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 242ba9b5d..6297e1df4 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -54,7 +54,7 @@ var _ vfs.FileDescriptionImpl = (*SignalFileDescription)(nil) // New creates a new signal fd. func New(vfsObj *vfs.VirtualFilesystem, target *kernel.Task, mask linux.SignalSet, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[signalfd]") - defer vd.DecRef() + defer vd.DecRef(target) sfd := &SignalFileDescription{ target: target, mask: mask, @@ -133,4 +133,4 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) { } // Release implements FileDescriptionImpl.Release() -func (sfd *SignalFileDescription) Release() {} +func (sfd *SignalFileDescription) Release(context.Context) {} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index ee0828a15..c61818ff6 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -67,9 +67,9 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 01ce30a4d..f81b0c38f 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -87,9 +87,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) - fs.Filesystem.Release() + fs.Filesystem.Release(ctx) } // dir implements kernfs.Inode. diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 242d5fd12..9fd38b295 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -59,7 +59,7 @@ func TestReadCPUFile(t *testing.T) { if err != nil { t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) content, err := s.ReadToEnd(fd) if err != nil { t.Fatalf("Read failed: %v", err) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index e743e8114..1e57744e8 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -127,7 +127,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns return nil, err } m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) - m.SetExecutable(fsbridge.NewVFSFile(exe)) + m.SetExecutable(ctx, fsbridge.NewVFSFile(exe)) config := &kernel.TaskConfig{ Kernel: k, diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go index 0556af877..568132121 100644 --- a/pkg/sentry/fsimpl/testutil/testutil.go +++ b/pkg/sentry/fsimpl/testutil/testutil.go @@ -97,8 +97,8 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System { // Destroy release resources associated with a test system. func (s *System) Destroy() { - s.Root.DecRef() - s.MntNs.DecRef() // Reference on MntNs passed to NewSystem. + s.Root.DecRef(s.Ctx) + s.MntNs.DecRef(s.Ctx) // Reference on MntNs passed to NewSystem. } // ReadToEnd reads the contents of fd until EOF to a string. @@ -149,7 +149,7 @@ func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector { if err != nil { s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) } - defer fd.DecRef() + defer fd.DecRef(s.Ctx) collector := &DirentCollector{} if err := fd.IterDirents(s.Ctx, collector); err != nil { diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 2dc90d484..86beaa0a8 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -47,9 +47,9 @@ var _ vfs.FileDescriptionImpl = (*TimerFileDescription)(nil) var _ ktime.TimerListener = (*TimerFileDescription)(nil) // New returns a new timer fd. -func New(vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) { +func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) { vd := vfsObj.NewAnonVirtualDentry("[timerfd]") - defer vd.DecRef() + defer vd.DecRef(ctx) tfd := &TimerFileDescription{} tfd.timer = ktime.NewTimer(clock, tfd) if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{ @@ -129,7 +129,7 @@ func (tfd *TimerFileDescription) ResumeTimer() { } // Release implements FileDescriptionImpl.Release() -func (tfd *TimerFileDescription) Release() { +func (tfd *TimerFileDescription) Release(context.Context) { tfd.timer.Destroy() } diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index 2fb5c4d84..d263147c2 100644 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -83,7 +83,7 @@ func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent } err = fn(root, d) - d.DecRef() + d.DecRef(ctx) return err } @@ -105,17 +105,17 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to create mount namespace: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create nested directories with given depth. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) d := root d.IncRef() - defer d.DecRef() + defer d.DecRef(ctx) for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { @@ -125,7 +125,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - d.DecRef() + d.DecRef(ctx) d = next filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -136,7 +136,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - file.DecRef() + file.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -176,7 +176,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { // Create VFS. vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { b.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -186,14 +186,14 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create nested directories with given depth. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) vd := root vd.IncRef() for i := depth; i > 0; i-- { @@ -212,7 +212,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - vd.DecRef() + vd.DecRef(ctx) vd = nextVD filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -228,12 +228,12 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) - vd.DecRef() + vd.DecRef(ctx) vd = vfs.VirtualDentry{} if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - defer fd.DecRef() + defer fd.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -278,14 +278,14 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to create mount namespace: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create and mount the submount. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil { b.Fatalf("failed to create mount point: %v", err) } @@ -293,7 +293,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil) if err != nil { b.Fatalf("failed to create tmpfs submount: %v", err) @@ -309,7 +309,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount root: %v", err) } - defer d.DecRef() + defer d.DecRef(ctx) for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil { @@ -319,7 +319,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - d.DecRef() + d.DecRef(ctx) d = next filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -330,7 +330,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - file.DecRef() + file.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() @@ -370,7 +370,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { b.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -380,14 +380,14 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef() + defer mntns.DecRef(ctx) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') // Create the mount point. root := mntns.Root() - defer root.DecRef() + defer root.DecRef(ctx) pop := vfs.PathOperation{ Root: root, Start: root, @@ -403,7 +403,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount point: %v", err) } - defer mountPoint.DecRef() + defer mountPoint.DecRef(ctx) // Create and mount the submount. if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) @@ -432,7 +432,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to directory %q: %v", name, err) } - vd.DecRef() + vd.DecRef(ctx) vd = nextVD filePathBuilder.WriteString(name) filePathBuilder.WriteByte('/') @@ -448,11 +448,11 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) - vd.DecRef() + vd.DecRef(ctx) if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } - fd.DecRef() + fd.DecRef(ctx) filePathBuilder.WriteString(filename) filePath := filePathBuilder.String() diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 0a1ad4765..78b4fc5be 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -95,7 +95,7 @@ type directoryFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *directoryFD) Release() { +func (fd *directoryFD) Release(ctx context.Context) { if fd.iter != nil { dir := fd.inode().impl.(*directory) dir.iterMu.Lock() @@ -110,7 +110,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fs := fd.filesystem() dir := fd.inode().impl.(*directory) - defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + defer fd.dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) // fs.mu is required to read d.parent and dentry.name. fs.mu.RLock() diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index ef210a69b..fb77f95cc 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -40,7 +40,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // stepLocked is loosely analogous to fs/namei.c:walk_component(). // // Preconditions: filesystem.mu must be locked. !rp.Done(). -func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { +func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { dir, ok := d.inode.impl.(*directory) if !ok { return nil, syserror.ENOTDIR @@ -55,13 +55,13 @@ afterSymlink: return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(&d.vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() return d, nil } - if err := rp.CheckMount(&d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { return nil, err } rp.Advance() @@ -74,7 +74,7 @@ afterSymlink: if !ok { return nil, syserror.ENOENT } - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { @@ -98,9 +98,9 @@ afterSymlink: // fs/namei.c:path_parentat(). // // Preconditions: filesystem.mu must be locked. !rp.Done(). -func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) { +func walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*directory, error) { for !rp.Final() { - next, err := stepLocked(rp, d) + next, err := stepLocked(ctx, rp, d) if err != nil { return nil, err } @@ -118,10 +118,10 @@ func walkParentDirLocked(rp *vfs.ResolvingPath, d *dentry) (*directory, error) { // resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). // // Preconditions: filesystem.mu must be locked. -func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { +func resolveLocked(ctx context.Context, rp *vfs.ResolvingPath) (*dentry, error) { d := rp.Start().Impl().(*dentry) for !rp.Done() { - next, err := stepLocked(rp, d) + next, err := stepLocked(ctx, rp, d) if err != nil { return nil, err } @@ -141,10 +141,10 @@ func resolveLocked(rp *vfs.ResolvingPath) (*dentry, error) { // // Preconditions: !rp.Done(). For the final path component in rp, // !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error { +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -182,7 +182,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if dir { ev |= linux.IN_ISDIR } - parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) parentDir.inode.touchCMtime() return nil } @@ -191,7 +191,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return err } @@ -202,7 +202,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -222,7 +222,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) { fs.mu.RLock() defer fs.mu.RUnlock() - dir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + dir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return nil, err } @@ -232,7 +232,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -251,7 +251,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EMLINK } i.incLinksLocked() - i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) + i.watches.Notify(ctx, "", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) parentDir.insertChildLocked(fs.newDentry(i), name) return nil }) @@ -259,7 +259,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { - return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() if parentDir.inode.nlink == maxLinks { return syserror.EMLINK @@ -273,7 +273,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { @@ -308,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -330,7 +330,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf return start.open(ctx, rp, &opts, false /* afterCreate */) } afterTrailingSymlink: - parentDir, err := walkParentDirLocked(rp, start) + parentDir, err := walkParentDirLocked(ctx, rp, start) if err != nil { return nil, err } @@ -368,7 +368,7 @@ afterTrailingSymlink: if err != nil { return nil, err } - parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) parentDir.inode.touchCMtime() return fd, nil } @@ -376,7 +376,7 @@ afterTrailingSymlink: return nil, syserror.EEXIST } // Is the file mounted over? - if err := rp.CheckMount(&child.vfsd); err != nil { + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { return nil, err } // Do we need to resolve a trailing symlink? @@ -445,7 +445,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return "", err } @@ -467,7 +467,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Resolve newParent first to verify that it's on this Mount. fs.mu.Lock() defer fs.mu.Unlock() - newParentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -555,7 +555,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) var replacedVFSD *vfs.Dentry if replaced != nil { replacedVFSD = &replaced.vfsd @@ -566,17 +566,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if replaced != nil { newParentDir.removeChildLocked(replaced) if replaced.inode.isDir() { - newParentDir.inode.decLinksLocked() // from replaced's ".." + newParentDir.inode.decLinksLocked(ctx) // from replaced's ".." } - replaced.inode.decLinksLocked() + replaced.inode.decLinksLocked(ctx) } oldParentDir.removeChildLocked(renamed) newParentDir.insertChildLocked(renamed, newName) - vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD) + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) oldParentDir.inode.touchCMtime() if oldParentDir != newParentDir { if renamed.inode.isDir() { - oldParentDir.inode.decLinksLocked() + oldParentDir.inode.decLinksLocked(ctx) newParentDir.inode.incLinksLocked() } newParentDir.inode.touchCMtime() @@ -591,7 +591,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -626,17 +626,17 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error defer mnt.EndWrite() vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } parentDir.removeChildLocked(child) - parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) + parentDir.inode.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) // Remove links for child, child/., and child/.. - child.inode.decLinksLocked() - child.inode.decLinksLocked() - parentDir.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(&child.vfsd) + child.inode.decLinksLocked(ctx) + child.inode.decLinksLocked(ctx) + parentDir.inode.decLinksLocked(ctx) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) parentDir.inode.touchCMtime() return nil } @@ -644,7 +644,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err @@ -656,7 +656,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.mu.RUnlock() if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -665,7 +665,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return linux.Statx{}, err } @@ -678,7 +678,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() defer fs.mu.RUnlock() - if _, err := resolveLocked(rp); err != nil { + if _, err := resolveLocked(ctx, rp); err != nil { return linux.Statfs{}, err } statfs := linux.Statfs{ @@ -695,7 +695,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error { creds := rp.Credentials() child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target)) parentDir.insertChildLocked(child, name) @@ -707,7 +707,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - parentDir, err := walkParentDirLocked(rp, rp.Start().Impl().(*dentry)) + parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } @@ -738,7 +738,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error defer mnt.EndWrite() vfsObj := rp.VirtualFilesystem() mntns := vfs.MountNamespaceFromContext(ctx) - defer mntns.DecRef() + defer mntns.DecRef(ctx) if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } @@ -746,11 +746,11 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error // Generate inotify events. Note that this must take place before the link // count of the child is decremented, or else the watches may be dropped // before these events are added. - vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name) + vfs.InotifyRemoveChild(ctx, &child.inode.watches, &parentDir.inode.watches, name) parentDir.removeChildLocked(child) - child.inode.decLinksLocked() - vfsObj.CommitDeleteDentry(&child.vfsd) + child.inode.decLinksLocked(ctx) + vfsObj.CommitDeleteDentry(ctx, &child.vfsd) parentDir.inode.touchCMtime() return nil } @@ -759,7 +759,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -778,7 +778,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } @@ -789,7 +789,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { return "", err } @@ -799,7 +799,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err @@ -810,14 +810,14 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt } fs.mu.RUnlock() - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - d, err := resolveLocked(rp) + d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err @@ -828,7 +828,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, } fs.mu.RUnlock() - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 1614f2c39..ec2701d8b 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -32,7 +32,7 @@ const fileName = "mypipe" func TestSeparateFDs(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the read side. This is done in a concurrently because opening // One end the pipe blocks until the other end is opened. @@ -55,13 +55,13 @@ func TestSeparateFDs(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer wfd.DecRef() + defer wfd.DecRef(ctx) rfd, ok := <-rfdchan if !ok { t.Fatalf("failed to open pipe for reading %q", fileName) } - defer rfd.DecRef() + defer rfd.DecRef(ctx) const msg = "vamos azul" checkEmpty(ctx, t, rfd) @@ -71,7 +71,7 @@ func TestSeparateFDs(t *testing.T) { func TestNonblockingRead(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the read side as nonblocking. pop := vfs.PathOperation{ @@ -85,7 +85,7 @@ func TestNonblockingRead(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) } - defer rfd.DecRef() + defer rfd.DecRef(ctx) // Open the write side. openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} @@ -93,7 +93,7 @@ func TestNonblockingRead(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer wfd.DecRef() + defer wfd.DecRef(ctx) const msg = "geh blau" checkEmpty(ctx, t, rfd) @@ -103,7 +103,7 @@ func TestNonblockingRead(t *testing.T) { func TestNonblockingWriteError(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the write side as nonblocking, which should return ENXIO. pop := vfs.PathOperation{ @@ -121,7 +121,7 @@ func TestNonblockingWriteError(t *testing.T) { func TestSingleFD(t *testing.T) { ctx, creds, vfsObj, root := setup(t) - defer root.DecRef() + defer root.DecRef(ctx) // Open the pipe as readable and writable. pop := vfs.PathOperation{ @@ -135,7 +135,7 @@ func TestSingleFD(t *testing.T) { if err != nil { t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) } - defer fd.DecRef() + defer fd.DecRef(ctx) const msg = "forza blu" checkEmpty(ctx, t, fd) @@ -152,7 +152,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy // Create VFS. vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index abbaa5d60..0710b65db 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -270,7 +270,7 @@ type regularFileFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *regularFileFD) Release() { +func (fd *regularFileFD) Release(context.Context) { // noop } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 2545d88e9..68e615e8b 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -185,7 +185,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt case linux.S_IFDIR: root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry default: - fs.vfsfs.DecRef() + fs.vfsfs.DecRef(ctx) return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType) } return &fs.vfsfs, &root.vfsd, nil @@ -197,7 +197,7 @@ func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *au } // Release implements vfs.FilesystemImpl.Release. -func (fs *filesystem) Release() { +func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } @@ -249,12 +249,12 @@ func (d *dentry) TryIncRef() bool { } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef() { - d.inode.decRef() +func (d *dentry) DecRef(ctx context.Context) { + d.inode.decRef(ctx) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.inode.isDir() { events |= linux.IN_ISDIR } @@ -266,9 +266,9 @@ func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { d.inode.fs.mu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted) + d.parent.inode.watches.Notify(ctx, d.name, events, cookie, et, deleted) } - d.inode.watches.Notify("", events, cookie, et, deleted) + d.inode.watches.Notify(ctx, "", events, cookie, et, deleted) d.inode.fs.mu.RUnlock() } @@ -278,7 +278,7 @@ func (d *dentry) Watches() *vfs.Watches { } // OnZeroWatches implements vfs.Dentry.OnZeroWatches. -func (d *dentry) OnZeroWatches() {} +func (d *dentry) OnZeroWatches(context.Context) {} // inode represents a filesystem object. type inode struct { @@ -359,12 +359,12 @@ func (i *inode) incLinksLocked() { // remove a reference on i as well. // // Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -func (i *inode) decLinksLocked() { +func (i *inode) decLinksLocked(ctx context.Context) { if i.nlink == 0 { panic("tmpfs.inode.decLinksLocked() called with no existing links") } if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 { - i.decRef() + i.decRef(ctx) } } @@ -386,9 +386,9 @@ func (i *inode) tryIncRef() bool { } } -func (i *inode) decRef() { +func (i *inode) decRef(ctx context.Context) { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { - i.watches.HandleDeletion() + i.watches.HandleDeletion(ctx) if regFile, ok := i.impl.(*regularFile); ok { // Release memory used by regFile to store data. Since regFile is // no longer usable, we don't need to grab any locks or update any @@ -701,7 +701,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { - d.InotifyWithParent(ev, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil } @@ -724,7 +724,7 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption } // Generate inotify events. - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } @@ -736,13 +736,13 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { } // Generate inotify events. - d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } // NewMemfd creates a new tmpfs regular file and file description that can back // an anonymous fd created by memfd_create. -func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name string) (*vfs.FileDescription, error) { +func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { fs, ok := mount.Filesystem().Impl().(*filesystem) if !ok { panic("NewMemfd() called with non-tmpfs mount") @@ -757,7 +757,7 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s } d := fs.newDentry(inode) - defer d.DecRef() + defer d.DecRef(ctx) d.name = name // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go index a240fb276..6f3e3ae6f 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go @@ -34,7 +34,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr creds := auth.CredentialsFromContext(ctx) vfsObj := &vfs.VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } @@ -47,8 +47,8 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr } root := mntns.Root() return vfsObj, root, func() { - root.DecRef() - mntns.DecRef() + root.DecRef(ctx) + mntns.DecRef(ctx) }, nil } diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 920fe4329..52ed5cea2 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -17,6 +17,7 @@ package kernel import ( "syscall" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" @@ -31,7 +32,7 @@ type abstractEndpoint struct { } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (e *abstractEndpoint) WeakRefGone() { +func (e *abstractEndpoint) WeakRefGone(context.Context) { e.ns.mu.Lock() if e.ns.endpoints[e.name].ep == e.ep { delete(e.ns.endpoints, e.name) @@ -64,9 +65,9 @@ type boundEndpoint struct { } // Release implements transport.BoundEndpoint.Release. -func (e *boundEndpoint) Release() { - e.rc.DecRef() - e.BoundEndpoint.Release() +func (e *boundEndpoint) Release(ctx context.Context) { + e.rc.DecRef(ctx) + e.BoundEndpoint.Release(ctx) } // BoundEndpoint retrieves the endpoint bound to the given name. The return @@ -93,13 +94,13 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by rc is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, rc refs.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() if ep, ok := a.endpoints[name]; ok { if rc := ep.wr.Get(); rc != nil { - rc.DecRef() + rc.DecRef(ctx) return syscall.EADDRINUSE } } diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 4c0f1e41f..15519f0df 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -76,8 +76,8 @@ type pollEntry struct { // WeakRefGone implements refs.WeakRefUser.WeakRefGone. // weakReferenceGone is called when the file in the weak reference is destroyed. // The poll entry is removed in response to this. -func (p *pollEntry) WeakRefGone() { - p.epoll.RemoveEntry(p.id) +func (p *pollEntry) WeakRefGone(ctx context.Context) { + p.epoll.RemoveEntry(ctx, p.id) } // EventPoll holds all the state associated with an event poll object, that is, @@ -144,14 +144,14 @@ func NewEventPoll(ctx context.Context) *fs.File { // name matches fs/eventpoll.c:epoll_create1. dirent := fs.NewDirent(ctx, anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]")) // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{ files: make(map[FileIdentifier]*pollEntry), }) } // Release implements fs.FileOperations.Release. -func (e *EventPoll) Release() { +func (e *EventPoll) Release(ctx context.Context) { // We need to take the lock now because files may be attempting to // remove entries in parallel if they get destroyed. e.mu.Lock() @@ -160,7 +160,7 @@ func (e *EventPoll) Release() { // Go through all entries and clean up. for _, entry := range e.files { entry.id.File.EventUnregister(&entry.waiter) - entry.file.Drop() + entry.file.Drop(ctx) } e.files = nil } @@ -423,7 +423,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter } // RemoveEntry a files from the collection of observed files. -func (e *EventPoll) RemoveEntry(id FileIdentifier) error { +func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error { e.mu.Lock() defer e.mu.Unlock() @@ -445,7 +445,7 @@ func (e *EventPoll) RemoveEntry(id FileIdentifier) error { // Remove file from map, and drop weak reference. delete(e.files, id) - entry.file.Drop() + entry.file.Drop(ctx) return nil } diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go index 22630e9c5..55b505593 100644 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ b/pkg/sentry/kernel/epoll/epoll_test.go @@ -26,7 +26,8 @@ func TestFileDestroyed(t *testing.T) { f := filetest.NewTestFile(t) id := FileIdentifier{f, 12} - efile := NewEventPoll(contexttest.Context(t)) + ctx := contexttest.Context(t) + efile := NewEventPoll(ctx) e := efile.FileOperations.(*EventPoll) if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil { t.Fatalf("addEntry failed: %v", err) @@ -44,7 +45,7 @@ func TestFileDestroyed(t *testing.T) { } // Destroy the file. Check that we get no more events. - f.DecRef() + f.DecRef(ctx) evt = e.ReadEvents(1) if len(evt) != 0 { diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 87951adeb..bbf568dfc 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -70,7 +70,7 @@ func New(ctx context.Context, initVal uint64, semMode bool) *fs.File { // name matches fs/eventfd.c:eventfd_file_create. dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventfd]") // Release the initial dirent reference after NewFile takes a reference. - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{ val: initVal, semMode: semMode, @@ -106,7 +106,7 @@ func (e *EventOperations) HostFD() (int, error) { } // Release implements fs.FileOperations.Release. -func (e *EventOperations) Release() { +func (e *EventOperations) Release(context.Context) { e.mu.Lock() defer e.mu.Unlock() if e.hostfd >= 0 { diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 4b7d234a4..ce53af69b 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -98,7 +98,7 @@ type FDTable struct { func (f *FDTable) saveDescriptorTable() map[int32]descriptor { m := make(map[int32]descriptor) - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { m[fd] = descriptor{ file: file, fileVFS2: fileVFS2, @@ -109,6 +109,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { } func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { + ctx := context.Background() f.init() // Initialize table. for fd, d := range m { f.setAll(fd, d.file, d.fileVFS2, d.flags) @@ -118,9 +119,9 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { // reference taken by set above. switch { case d.file != nil: - d.file.DecRef() + d.file.DecRef(ctx) case d.fileVFS2 != nil: - d.fileVFS2.DecRef() + d.fileVFS2.DecRef(ctx) } } } @@ -144,14 +145,15 @@ func (f *FDTable) drop(file *fs.File) { d.InotifyEvent(ev, 0) // Drop the table reference. - file.DecRef() + file.DecRef(context.Background()) } // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the // entire file. - err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET) + ctx := context.Background() + err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } @@ -161,10 +163,10 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) { if file.IsWritable() { ev = linux.IN_CLOSE_WRITE } - file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent) // Drop the table's reference. - file.DecRef() + file.DecRef(ctx) } // NewFDTable allocates a new FDTable that may be used by tasks in k. @@ -175,15 +177,15 @@ func (k *Kernel) NewFDTable() *FDTable { } // destroy removes all of the file descriptors from the map. -func (f *FDTable) destroy() { - f.RemoveIf(func(*fs.File, *vfs.FileDescription, FDFlags) bool { +func (f *FDTable) destroy(ctx context.Context) { + f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool { return true }) } // DecRef implements RefCounter.DecRef with destructor f.destroy. -func (f *FDTable) DecRef() { - f.DecRefWithDestructor(f.destroy) +func (f *FDTable) DecRef(ctx context.Context) { + f.DecRefWithDestructor(ctx, f.destroy) } // Size returns the number of file descriptor slots currently allocated. @@ -195,7 +197,7 @@ func (f *FDTable) Size() int { // forEach iterates over all non-nil files in sorted order. // // It is the caller's responsibility to acquire an appropriate lock. -func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { +func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { // retries tracks the number of failed TryIncRef attempts for the same FD. retries := 0 fd := int32(0) @@ -214,7 +216,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes continue // Race caught. } fn(fd, file, nil, flags) - file.DecRef() + file.DecRef(ctx) case fileVFS2 != nil: if !fileVFS2.TryIncRef() { retries++ @@ -224,7 +226,7 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes continue // Race caught. } fn(fd, nil, fileVFS2, flags) - fileVFS2.DecRef() + fileVFS2.DecRef(ctx) } retries = 0 fd++ @@ -234,7 +236,8 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes // String is a stringer for FDTable. func (f *FDTable) String() string { var buf strings.Builder - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + ctx := context.Background() + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { switch { case file != nil: n, _ := file.Dirent.FullName(nil /* root */) @@ -242,7 +245,7 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() - name, err := vfsObj.PathnameWithDeleted(context.Background(), vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) + name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "\n", err) return @@ -541,9 +544,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) { // // Precondition: The caller must be running on the task goroutine, or Task.mu // must be locked. -func (f *FDTable) GetFDs() []int32 { +func (f *FDTable) GetFDs(ctx context.Context) []int32 { fds := make([]int32, 0, int(atomic.LoadInt32(&f.used))) - f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { + f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { fds = append(fds, fd) }) return fds @@ -552,9 +555,9 @@ func (f *FDTable) GetFDs() []int32 { // GetRefs returns a stable slice of references to all files and bumps the // reference count on each. The caller must use DecRef on each reference when // they're done using the slice. -func (f *FDTable) GetRefs() []*fs.File { +func (f *FDTable) GetRefs(ctx context.Context) []*fs.File { files := make([]*fs.File, 0, f.Size()) - f.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { file.IncRef() // Acquire a reference for caller. files = append(files, file) }) @@ -564,9 +567,9 @@ func (f *FDTable) GetRefs() []*fs.File { // GetRefsVFS2 returns a stable slice of references to all files and bumps the // reference count on each. The caller must use DecRef on each reference when // they're done using the slice. -func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription { +func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription { files := make([]*vfs.FileDescription, 0, f.Size()) - f.forEach(func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) { + f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) { file.IncRef() // Acquire a reference for caller. files = append(files, file) }) @@ -574,10 +577,10 @@ func (f *FDTable) GetRefsVFS2() []*vfs.FileDescription { } // Fork returns an independent FDTable. -func (f *FDTable) Fork() *FDTable { +func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. switch { @@ -622,11 +625,11 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { } // RemoveIf removes all FDs where cond is true. -func (f *FDTable) RemoveIf(cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { +func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { f.mu.Lock() defer f.mu.Unlock() - f.forEach(func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { + f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { f.set(fd, nil, FDFlags{}) // Clear from table. // Update current available position. diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 29f95a2c4..e3f30ba2a 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -154,7 +154,7 @@ func TestFDTable(t *testing.T) { if ref == nil { t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") } - ref.DecRef() + ref.DecRef(ctx) if ref, _ := fdTable.Remove(1); ref != nil { t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") @@ -191,7 +191,7 @@ func BenchmarkFDLookupAndDecRef(b *testing.B) { b.StartTimer() // Benchmark. for i := 0; i < b.N; i++ { tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() + tf.DecRef(ctx) } }) } @@ -219,7 +219,7 @@ func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) { defer wg.Done() for i := 0; i < each; i++ { tf, _ := fdTable.Get(fds[i%len(fds)]) - tf.DecRef() + tf.DecRef(ctx) } }() } diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 47f78df9a..8f2d36d5a 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -17,6 +17,7 @@ package kernel import ( "fmt" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -89,28 +90,28 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext { // Note that there may still be calls to WorkingDirectory() or RootDirectory() // (that return nil). This is because valid references may still be held via // proc files or other mechanisms. -func (f *FSContext) destroy() { +func (f *FSContext) destroy(ctx context.Context) { // Hold f.mu so that we don't race with RootDirectory() and // WorkingDirectory(). f.mu.Lock() defer f.mu.Unlock() if VFS2Enabled { - f.rootVFS2.DecRef() + f.rootVFS2.DecRef(ctx) f.rootVFS2 = vfs.VirtualDentry{} - f.cwdVFS2.DecRef() + f.cwdVFS2.DecRef(ctx) f.cwdVFS2 = vfs.VirtualDentry{} } else { - f.root.DecRef() + f.root.DecRef(ctx) f.root = nil - f.cwd.DecRef() + f.cwd.DecRef(ctx) f.cwd = nil } } // DecRef implements RefCounter.DecRef with destructor f.destroy. -func (f *FSContext) DecRef() { - f.DecRefWithDestructor(f.destroy) +func (f *FSContext) DecRef(ctx context.Context) { + f.DecRefWithDestructor(ctx, f.destroy) } // Fork forks this FSContext. @@ -165,7 +166,7 @@ func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { // This will take an extra reference on the Dirent. // // This is not a valid call after destroy. -func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) { +func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetWorkingDirectory called with nil dirent") } @@ -180,21 +181,21 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) { old := f.cwd f.cwd = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // SetWorkingDirectoryVFS2 sets the current working directory. // This will take an extra reference on the VirtualDentry. // // This is not a valid call after destroy. -func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) { +func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDentry) { f.mu.Lock() defer f.mu.Unlock() old := f.cwdVFS2 f.cwdVFS2 = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // RootDirectory returns the current filesystem root. @@ -226,7 +227,7 @@ func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { // This will take an extra reference on the Dirent. // // This is not a valid call after free. -func (f *FSContext) SetRootDirectory(d *fs.Dirent) { +func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetRootDirectory called with nil dirent") } @@ -241,13 +242,13 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) { old := f.root f.root = d d.IncRef() - old.DecRef() + old.DecRef(ctx) } // SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd. // // This is not a valid call after free. -func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { +func (f *FSContext) SetRootDirectoryVFS2(ctx context.Context, vd vfs.VirtualDentry) { if !vd.Ok() { panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry") } @@ -263,7 +264,7 @@ func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) { vd.IncRef() f.rootVFS2 = vd f.mu.Unlock() - old.DecRef() + old.DecRef(ctx) } // Umask returns the current umask. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index c5021f2db..daa2dae76 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -51,6 +51,7 @@ go_test( srcs = ["futex_test.go"], library = ":futex", deps = [ + "//pkg/context", "//pkg/sync", "//pkg/usermem", ], diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index bcc1b29a8..e4dcc4d40 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -19,6 +19,7 @@ package futex import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -66,9 +67,9 @@ type Key struct { Offset uint64 } -func (k *Key) release() { +func (k *Key) release(t Target) { if k.MappingIdentity != nil { - k.MappingIdentity.DecRef() + k.MappingIdentity.DecRef(t) } k.Mappable = nil k.MappingIdentity = nil @@ -94,6 +95,8 @@ func (k *Key) matches(k2 *Key) bool { // Target abstracts memory accesses and keys. type Target interface { + context.Context + // SwapUint32 gives access to usermem.IO.SwapUint32. SwapUint32(addr usermem.Addr, new uint32) (uint32, error) @@ -296,7 +299,7 @@ func (b *bucket) wakeWaiterLocked(w *Waiter) { // bucket "to". // // Preconditions: b and to must be locked. -func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int { +func (b *bucket) requeueLocked(t Target, to *bucket, key, nkey *Key, n int) int { done := 0 for w := b.waiters.Front(); done < n && w != nil; { if !w.key.matches(key) { @@ -308,7 +311,7 @@ func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int { requeued := w w = w.Next() // Next iteration. b.waiters.Remove(requeued) - requeued.key.release() + requeued.key.release(t) requeued.key = nkey.clone() to.waiters.PushBack(requeued) requeued.bucket.Store(to) @@ -456,7 +459,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32 r := b.wakeLocked(&k, bitmask, n) b.mu.Unlock() - k.release() + k.release(t) return r, nil } @@ -465,12 +468,12 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch if err != nil { return 0, err } - defer k1.release() + defer k1.release(t) k2, err := getKey(t, naddr, private) if err != nil { return 0, err } - defer k2.release() + defer k2.release(t) b1, b2 := m.lockBuckets(&k1, &k2) defer b1.mu.Unlock() @@ -488,7 +491,7 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch done := b1.wakeLocked(&k1, ^uint32(0), nwake) // Requeue the number required. - b1.requeueLocked(b2, &k1, &k2, nreq) + b1.requeueLocked(t, b2, &k1, &k2, nreq) return done, nil } @@ -515,12 +518,12 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak if err != nil { return 0, err } - defer k1.release() + defer k1.release(t) k2, err := getKey(t, addr2, private) if err != nil { return 0, err } - defer k2.release() + defer k2.release(t) b1, b2 := m.lockBuckets(&k1, &k2) defer b1.mu.Unlock() @@ -571,7 +574,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo // Perform our atomic check. if err := check(t, addr, val); err != nil { b.mu.Unlock() - w.key.release() + w.key.release(t) return err } @@ -585,7 +588,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo // WaitComplete must be called when a Waiter previously added by WaitPrepare is // no longer eligible to be woken. -func (m *Manager) WaitComplete(w *Waiter) { +func (m *Manager) WaitComplete(w *Waiter, t Target) { // Remove w from the bucket it's in. for { b := w.bucket.Load() @@ -617,7 +620,7 @@ func (m *Manager) WaitComplete(w *Waiter) { } // Release references held by the waiter. - w.key.release() + w.key.release(t) } // LockPI attempts to lock the futex following the Priority-inheritance futex @@ -648,13 +651,13 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri success, err := m.lockPILocked(w, t, addr, tid, b, try) if err != nil { - w.key.release() + w.key.release(t) b.mu.Unlock() return false, err } if success || try { // Release waiter if it's not going to be a wait. - w.key.release() + w.key.release(t) } b.mu.Unlock() return success, nil @@ -730,7 +733,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool err = m.unlockPILocked(t, addr, tid, b, &k) - k.release() + k.release(t) b.mu.Unlock() return err } diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 7c5c7665b..d0128c548 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -22,6 +22,7 @@ import ( "testing" "unsafe" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -29,28 +30,33 @@ import ( // testData implements the Target interface, and allows us to // treat the address passed for futex operations as an index in // a byte slice for testing simplicity. -type testData []byte +type testData struct { + context.Context + data []byte +} const sizeofInt32 = 4 func newTestData(size uint) testData { - return make([]byte, size) + return testData{ + data: make([]byte, size), + } } func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { - val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new) + val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new) return val, nil } func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { - if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) { + if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) { return old, nil } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil } func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) { @@ -83,7 +89,7 @@ func TestFutexWake(t *testing.T) { // Start waiting for wakeup. w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w) + defer m.WaitComplete(w, d) // Perform a wakeup. if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 { @@ -106,7 +112,7 @@ func TestFutexWakeBitmask(t *testing.T) { // Start waiting for wakeup. w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff) - defer m.WaitComplete(w) + defer m.WaitComplete(w, d) // Perform a wakeup using the wrong bitmask. if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 { @@ -141,7 +147,7 @@ func TestFutexWakeTwo(t *testing.T) { var ws [3]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform two wakeups. @@ -174,9 +180,9 @@ func TestFutexWakeUnrelated(t *testing.T) { // Start two waiters waiting for wakeup on different addresses. w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform two wakeups on the second address. if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 { @@ -216,9 +222,9 @@ func TestWakeOpFirstNonEmpty(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address 0. if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 { @@ -244,9 +250,9 @@ func TestWakeOpSecondNonEmpty(t *testing.T) { // Add two waiters on address sizeofInt32. w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address sizeofInt32 (contingent on // d.Op(0), which should succeed). @@ -273,9 +279,9 @@ func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) { // Add two waiters on address sizeofInt32. w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Perform 10 wakeups on address sizeofInt32 (contingent on // d.Op(1), which should fail). @@ -302,15 +308,15 @@ func TestWakeOpAllNonEmpty(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Add two waiters on address sizeofInt32. w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) + defer m.WaitComplete(w3, d) w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) + defer m.WaitComplete(w4, d) // Perform 10 wakeups on address 0 (unconditionally), and 10 // wakeups on address sizeofInt32 (contingent on d.Op(0), which @@ -344,15 +350,15 @@ func TestWakeOpAllNonEmptyFailingOp(t *testing.T) { // Add two waiters on address 0. w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) + defer m.WaitComplete(w1, d) w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) + defer m.WaitComplete(w2, d) // Add two waiters on address sizeofInt32. w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) + defer m.WaitComplete(w3, d) w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) + defer m.WaitComplete(w4, d) // Perform 10 wakeups on address 0 (unconditionally), and 10 // wakeups on address sizeofInt32 (contingent on d.Op(1), which @@ -388,7 +394,7 @@ func TestWakeOpSameAddress(t *testing.T) { var ws [4]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup @@ -422,7 +428,7 @@ func TestWakeOpSameAddressFailingOp(t *testing.T) { var ws [4]*Waiter for i := range ws { ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) + defer m.WaitComplete(ws[i], d) } // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup @@ -472,7 +478,7 @@ func (t *testMutex) Lock() { for { // Attempt to grab the lock. if atomic.CompareAndSwapUint32( - (*uint32)(unsafe.Pointer(&t.d[t.a])), + (*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked, testMutexLocked) { // Lock held. @@ -490,7 +496,7 @@ func (t *testMutex) Lock() { panic("WaitPrepare returned unexpected error: " + err.Error()) } <-w.C - t.m.WaitComplete(w) + t.m.WaitComplete(w, t.d) } } @@ -498,7 +504,7 @@ func (t *testMutex) Lock() { // This will notify any waiters via the futex manager. func (t *testMutex) Unlock() { // Unlock. - atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked) + atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked) // Notify all waiters. t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 15dae0f5b..316df249d 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -376,7 +376,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.netlinkPorts = port.New() if VFS2Enabled { - if err := k.vfs.Init(); err != nil { + ctx := k.SupervisorContext() + if err := k.vfs.Init(ctx); err != nil { return fmt.Errorf("failed to initialize VFS: %v", err) } @@ -384,19 +385,19 @@ func (k *Kernel) Init(args InitKernelArgs) error { if err != nil { return fmt.Errorf("failed to create pipefs filesystem: %v", err) } - defer pipeFilesystem.DecRef() + defer pipeFilesystem.DecRef(ctx) pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create pipefs mount: %v", err) } k.pipeMount = pipeMount - tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(k.SupervisorContext(), &k.vfs, auth.NewRootCredentials(k.rootUserNamespace)) + tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(ctx, &k.vfs, auth.NewRootCredentials(k.rootUserNamespace)) if err != nil { return fmt.Errorf("failed to create tmpfs filesystem: %v", err) } - defer tmpfsFilesystem.DecRef() - defer tmpfsRoot.DecRef() + defer tmpfsFilesystem.DecRef(ctx) + defer tmpfsRoot.DecRef(ctx) shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create tmpfs mount: %v", err) @@ -407,7 +408,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { if err != nil { return fmt.Errorf("failed to create sockfs filesystem: %v", err) } - defer socketFilesystem.DecRef() + defer socketFilesystem.DecRef(ctx) socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to create sockfs mount: %v", err) @@ -430,8 +431,8 @@ func (k *Kernel) SaveTo(w wire.Writer) error { defer k.extMu.Unlock() // Stop time. - k.pauseTimeLocked() - defer k.resumeTimeLocked() + k.pauseTimeLocked(ctx) + defer k.resumeTimeLocked(ctx) // Evict all evictable MemoryFile allocations. k.mf.StartEvictions() @@ -447,12 +448,12 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Remove all epoll waiter objects from underlying wait queues. // NOTE: for programs to resume execution in future snapshot scenarios, // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters() + k.tasks.unregisterEpollWaiters(ctx) // Clear the dirent cache before saving because Dirents must be Loaded in a // particular order (parents before children), and Loading dirents from a cache // breaks that order. - if err := k.flushMountSourceRefs(); err != nil { + if err := k.flushMountSourceRefs(ctx); err != nil { return err } @@ -505,7 +506,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. -func (k *Kernel) flushMountSourceRefs() error { +func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -521,7 +522,7 @@ func (k *Kernel) flushMountSourceRefs() error { // There may be some open FDs whose filesystems have been unmounted. We // must flush those as well. - return k.tasks.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { file.Dirent.Inode.MountSource.FlushDirentRefs() return nil }) @@ -531,7 +532,7 @@ func (k *Kernel) flushMountSourceRefs() error { // each task. // // Precondition: Must be called with the kernel paused. -func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) { +func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. if VFS2Enabled { return nil @@ -544,7 +545,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) if t.fdTable == nil { continue } - t.fdTable.forEach(func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) { if lastErr := f(file, fileVFS2); lastErr != nil && err == nil { err = lastErr } @@ -555,7 +556,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error { + return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -602,7 +603,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { return nil } -func (ts *TaskSet) unregisterEpollWaiters() { +func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. if VFS2Enabled { return @@ -623,7 +624,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { if _, ok := processed[t.fdTable]; ok { continue } - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { if e, ok := file.FileOperations.(*epoll.EventPoll); ok { e.UnregisterEpollWaiters() } @@ -900,7 +901,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, root := args.MountNamespaceVFS2.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. - defer root.DecRef() + defer root.DecRef(ctx) // Grab the working directory. wd := root // Default. @@ -918,7 +919,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err) } - defer wd.DecRef() + defer wd.DecRef(ctx) } opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd) fsContext = NewFSContextVFS2(root, wd, args.Umask) @@ -933,7 +934,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, root := mntns.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. - defer root.DecRef() + defer root.DecRef(ctx) // Grab the working directory. remainingTraversals := args.MaxSymlinkTraversals @@ -944,7 +945,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err) } - defer wd.DecRef() + defer wd.DecRef(ctx) } opener = fsbridge.NewFSLookup(mntns, root, wd) fsContext = newFSContext(root, wd, args.Umask) @@ -1054,7 +1055,7 @@ func (k *Kernel) Start() error { // If k was created by LoadKernelFrom, timers were stopped during // Kernel.SaveTo and need to be resumed. If k was created by NewKernel, // this is a no-op. - k.resumeTimeLocked() + k.resumeTimeLocked(k.SupervisorContext()) // Start task goroutines. k.tasks.mu.RLock() defer k.tasks.mu.RUnlock() @@ -1068,7 +1069,7 @@ func (k *Kernel) Start() error { // // Preconditions: Any task goroutines running in k must be stopped. k.extMu // must be locked. -func (k *Kernel) pauseTimeLocked() { +func (k *Kernel) pauseTimeLocked(ctx context.Context) { // k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before // Kernel.Start(). if k.cpuClockTicker != nil { @@ -1090,7 +1091,7 @@ func (k *Kernel) pauseTimeLocked() { // This means we'll iterate FDTables shared by multiple tasks repeatedly, // but ktime.Timer.Pause is idempotent so this is harmless. if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { if VFS2Enabled { if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok { tfd.PauseTimer() @@ -1112,7 +1113,7 @@ func (k *Kernel) pauseTimeLocked() { // // Preconditions: Any task goroutines running in k must be stopped. k.extMu // must be locked. -func (k *Kernel) resumeTimeLocked() { +func (k *Kernel) resumeTimeLocked(ctx context.Context) { if k.cpuClockTicker != nil { k.cpuClockTicker.Resume() } @@ -1126,7 +1127,7 @@ func (k *Kernel) resumeTimeLocked() { } } if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { if VFS2Enabled { if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok { tfd.ResumeTimer() @@ -1511,7 +1512,7 @@ type SocketEntry struct { } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketEntry) WeakRefGone() { +func (s *SocketEntry) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1600,7 +1601,7 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { return vfs.VirtualDentry{} } mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2() - defer mntns.DecRef() + defer mntns.DecRef(ctx) // Root() takes a reference on the root dirent for us. return mntns.Root() case vfs.CtxMountNamespace: diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4b688c627..6497dc4ba 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -93,7 +93,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { if !waitFor(&i.mu, &i.wWakeup, ctx) { - r.DecRef() + r.DecRef(ctx) return nil, syserror.ErrInterrupted } } @@ -111,12 +111,12 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi // On a nonblocking, write-only open, the open fails with ENXIO if the // read side isn't open yet. if flags.NonBlocking { - w.DecRef() + w.DecRef(ctx) return nil, syserror.ENXIO } if !waitFor(&i.mu, &i.rWakeup, ctx) { - w.DecRef() + w.DecRef(ctx) return nil, syserror.ErrInterrupted } } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ab75a87ff..ce0db5583 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -167,7 +167,7 @@ func TestClosedReaderBlocksWriteOpen(t *testing.T) { f := NewInodeOperations(ctx, perms, newNamedPipe(t)) rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil) - rFile.DecRef() + rFile.DecRef(ctx) wDone := make(chan struct{}) // This open for write should block because the reader is now gone. diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 79645d7d2..297e8f28f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -152,7 +152,7 @@ func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs. d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino)) // The p.Open calls below will each take a reference on the Dirent. We // must drop the one we already have. - defer d.DecRef() + defer d.DecRef(ctx) return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true}) } diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index bda739dbe..fe97e9800 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -27,8 +27,8 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := []byte("here's some bytes") wantN := int64(len(msg)) @@ -47,8 +47,8 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) if n != 0 || err != syserror.ErrWouldBlock { @@ -62,8 +62,8 @@ func TestPipeWriteBlock(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := make([]byte, capacity+1) n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) @@ -77,8 +77,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { ctx := contexttest.Context(t) r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(ctx) + defer w.DecRef(ctx) msg := []byte("here's some bytes") diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index aacf28da2..6d58b682f 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -33,7 +33,7 @@ import ( // the old fs architecture. // Release cleans up the pipe's state. -func (p *Pipe) Release() { +func (p *Pipe) Release(context.Context) { p.rClose() p.wClose() diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go index 7724b4452..ac18785c0 100644 --- a/pkg/sentry/kernel/pipe/reader.go +++ b/pkg/sentry/kernel/pipe/reader.go @@ -15,6 +15,7 @@ package pipe import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/waiter" ) @@ -29,7 +30,7 @@ type Reader struct { // Release implements fs.FileOperations.Release. // // This overrides ReaderWriter.Release. -func (r *Reader) Release() { +func (r *Reader) Release(context.Context) { r.Pipe.rClose() // Wake up writers. diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 45d4c5fc1..28f998e45 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -101,7 +101,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // If this pipe is being opened as blocking and there's no // writer, we have to wait for a writer to open the other end. if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EINTR } @@ -112,12 +112,12 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s // Non-blocking, write-only opens fail with ENXIO when the read // side isn't open yet. if statusFlags&linux.O_NONBLOCK != 0 { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.ENXIO } // Wait for a reader to open the other end. if !waitFor(&vp.mu, &vp.rWakeup, ctx) { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EINTR } } @@ -169,7 +169,7 @@ type VFSPipeFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *VFSPipeFD) Release() { +func (fd *VFSPipeFD) Release(context.Context) { var event waiter.EventMask if fd.vfsfd.IsReadable() { fd.pipe.rClose() diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go index 5bc6aa931..ef4b70ca3 100644 --- a/pkg/sentry/kernel/pipe/writer.go +++ b/pkg/sentry/kernel/pipe/writer.go @@ -15,6 +15,7 @@ package pipe import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/waiter" ) @@ -29,7 +30,7 @@ type Writer struct { // Release implements fs.FileOperations.Release. // // This overrides ReaderWriter.Release. -func (w *Writer) Release() { +func (w *Writer) Release(context.Context) { w.Pipe.wClose() // Wake up readers. diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 0e19286de..5c4c622c2 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" @@ -70,7 +71,7 @@ func (s *Session) incRef() { // // Precondition: callers must hold TaskSet.mu for writing. func (s *Session) decRef() { - s.refs.DecRefWithDestructor(func() { + s.refs.DecRefWithDestructor(nil, func(context.Context) { // Remove translations from the leader. for ns := s.leader.pidns; ns != nil; ns = ns.parent { id := ns.sids[s] @@ -162,7 +163,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) { } alive := true - pg.refs.DecRefWithDestructor(func() { + pg.refs.DecRefWithDestructor(nil, func(context.Context) { alive = false // don't bother with handleOrphan. // Remove translations from the originator. diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 55b4c2cdb..13ec7afe0 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -431,8 +431,8 @@ func (s *Shm) InodeID() uint64 { // DecRef overrides refs.RefCount.DecRef with a destructor. // // Precondition: Caller must not hold s.mu. -func (s *Shm) DecRef() { - s.DecRefWithDestructor(s.destroy) +func (s *Shm) DecRef(ctx context.Context) { + s.DecRefWithDestructor(ctx, s.destroy) } // Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm @@ -642,7 +642,7 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error { return nil } -func (s *Shm) destroy() { +func (s *Shm) destroy(context.Context) { s.mfp.MemoryFile().DecRef(s.fr) s.registry.remove(s) } @@ -651,7 +651,7 @@ func (s *Shm) destroy() { // destroyed once it has no references. MarkDestroyed may be called multiple // times, and is safe to call after a segment has already been destroyed. See // shmctl(IPC_RMID). -func (s *Shm) MarkDestroyed() { +func (s *Shm) MarkDestroyed(ctx context.Context) { s.registry.dissociateKey(s) s.mu.Lock() @@ -663,7 +663,7 @@ func (s *Shm) MarkDestroyed() { // // N.B. This cannot be the final DecRef, as the caller also // holds a reference. - s.DecRef() + s.DecRef(ctx) return } } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 8243bb93e..b07e1c1bd 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -76,7 +76,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { } // Release implements fs.FileOperations.Release. -func (s *SignalOperations) Release() {} +func (s *SignalOperations) Release(context.Context) {} // Mask returns the signal mask. func (s *SignalOperations) Mask() linux.SignalSet { diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index c4db05bd8..5aee699e7 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -730,17 +730,17 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock { func (t *Task) IsChrooted() bool { if VFS2Enabled { realRoot := t.mountNamespaceVFS2.Root() - defer realRoot.DecRef() + defer realRoot.DecRef(t) root := t.fsContext.RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) return root != realRoot } realRoot := t.tg.mounts.Root() - defer realRoot.DecRef() + defer realRoot.DecRef(t) root := t.fsContext.RootDirectory() if root != nil { - defer root.DecRef() + defer root.DecRef(t) } return root != realRoot } diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index e1ecca99e..fe6ba6041 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -237,7 +237,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { var fdTable *FDTable if opts.NewFiles { - fdTable = t.fdTable.Fork() + fdTable = t.fdTable.Fork(t) } else { fdTable = t.fdTable fdTable.IncRef() @@ -294,7 +294,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { nt, err := t.tg.pidns.owner.NewTask(cfg) if err != nil { if opts.NewThreadGroup { - tg.release() + tg.release(t) } return 0, nil, err } @@ -510,7 +510,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { var oldFDTable *FDTable if opts.NewFiles { oldFDTable = t.fdTable - t.fdTable = oldFDTable.Fork() + t.fdTable = oldFDTable.Fork(t) } var oldFSContext *FSContext if opts.NewFSContext { @@ -519,10 +519,10 @@ func (t *Task) Unshare(opts *SharingOptions) error { } t.mu.Unlock() if oldFDTable != nil { - oldFDTable.DecRef() + oldFDTable.DecRef(t) } if oldFSContext != nil { - oldFSContext.DecRef() + oldFSContext.DecRef(t) } return nil } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 7803b98d0..47c28b8ff 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -199,11 +199,11 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.Unlock() oldFDTable := t.fdTable - t.fdTable = t.fdTable.Fork() - oldFDTable.DecRef() + t.fdTable = t.fdTable.Fork(t) + oldFDTable.DecRef(t) // Remove FDs with the CloseOnExec flag set. - t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { + t.fdTable.RemoveIf(t, func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { return flags.CloseOnExec }) diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 231ac548a..c165d6cb1 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -269,12 +269,12 @@ func (*runExitMain) execute(t *Task) taskRunState { // Releasing the MM unblocks a blocked CLONE_VFORK parent. t.unstopVforkParent() - t.fsContext.DecRef() - t.fdTable.DecRef() + t.fsContext.DecRef(t) + t.fdTable.DecRef(t) t.mu.Lock() if t.mountNamespaceVFS2 != nil { - t.mountNamespaceVFS2.DecRef() + t.mountNamespaceVFS2.DecRef(t) t.mountNamespaceVFS2 = nil } t.mu.Unlock() @@ -282,7 +282,7 @@ func (*runExitMain) execute(t *Task) taskRunState { // If this is the last task to exit from the thread group, release the // thread group's resources. if lastExiter { - t.tg.release() + t.tg.release(t) } // Detach tracees. diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index eeccaa197..ab86ceedc 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -203,6 +203,6 @@ func (t *Task) traceExecEvent(tc *TaskContext) { trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") return } - defer file.DecRef() + defer file.DecRef(t) trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t)) } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 8485fb4b6..64c1e120a 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -102,10 +102,10 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { cfg.TaskContext.release() - cfg.FSContext.DecRef() - cfg.FDTable.DecRef() + cfg.FSContext.DecRef(t) + cfg.FDTable.DecRef(t) if cfg.MountNamespaceVFS2 != nil { - cfg.MountNamespaceVFS2.DecRef() + cfg.MountNamespaceVFS2.DecRef(t) } return nil, err } diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 4dfd2c990..0b34c0099 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -308,7 +308,7 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet { } // release releases the thread group's resources. -func (tg *ThreadGroup) release() { +func (tg *ThreadGroup) release(t *Task) { // Timers must be destroyed without holding the TaskSet or signal mutexes // since timers send signals with Timer.mu locked. tg.itimerRealTimer.Destroy() @@ -325,7 +325,7 @@ func (tg *ThreadGroup) release() { it.DestroyTimer() } if tg.mounts != nil { - tg.mounts.DecRef() + tg.mounts.DecRef(t) } } diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index ddeaff3db..20dd1cc21 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -281,7 +281,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr } defer func() { if mopts.MappingIdentity != nil { - mopts.MappingIdentity.DecRef() + mopts.MappingIdentity.DecRef(ctx) } }() if err := f.ConfigureMMap(ctx, &mopts); err != nil { @@ -663,7 +663,7 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err } - defer intFile.DecRef() + defer intFile.DecRef(ctx) interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin) if err != nil { diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 986c7fb4d..8d6802ea3 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -154,7 +154,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context return loadedELF{}, nil, nil, nil, err } // Ensure file is release in case the code loops or errors out. - defer args.File.DecRef() + defer args.File.DecRef(ctx) } else { if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil { return loadedELF{}, nil, nil, nil, err @@ -223,7 +223,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } - defer file.DecRef() + defer file.DecRef(ctx) // Load the VDSO. vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) @@ -292,7 +292,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V m.SetEnvvStart(sl.EnvvStart) m.SetEnvvEnd(sl.EnvvEnd) m.SetAuxv(auxv) - m.SetExecutable(file) + m.SetExecutable(ctx, file) symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn") if err != nil { diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c188f6c29..59c92c7e8 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -238,7 +238,7 @@ type MappingIdentity interface { IncRef() // DecRef decrements the MappingIdentity's reference count. - DecRef() + DecRef(ctx context.Context) // MappedName returns the application-visible name shown in // /proc/[pid]/maps. diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 1999ec706..16fea53c4 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -258,8 +258,8 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { } // DecRef implements refs.RefCounter.DecRef. -func (m *aioMappable) DecRef() { - m.AtomicRefCount.DecRefWithDestructor(func() { +func (m *aioMappable) DecRef(ctx context.Context) { + m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { m.mfp.MemoryFile().DecRef(m.fr) }) } @@ -367,7 +367,7 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint if err != nil { return 0, err } - defer m.DecRef() + defer m.DecRef(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ Length: aioRingBufferSize, MappingIdentity: m, diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index aac56679b..4d7773f8b 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -258,7 +258,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { mm.executable = nil mm.metadataMu.Unlock() if exe != nil { - exe.DecRef() + exe.DecRef(ctx) } mm.activeMu.Lock() diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 28e5057f7..0cfd60f6c 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -15,6 +15,7 @@ package mm import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/usermem" @@ -147,7 +148,7 @@ func (mm *MemoryManager) Executable() fsbridge.File { // SetExecutable sets the executable. // // This takes a reference on d. -func (mm *MemoryManager) SetExecutable(file fsbridge.File) { +func (mm *MemoryManager) SetExecutable(ctx context.Context, file fsbridge.File) { mm.metadataMu.Lock() // Grab a new reference. @@ -164,7 +165,7 @@ func (mm *MemoryManager) SetExecutable(file fsbridge.File) { // Do this without holding the lock, since it may wind up doing some // I/O to sync the dirent, etc. if orig != nil { - orig.DecRef() + orig.DecRef(ctx) } } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 0e142fb11..4cdb52eb6 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -50,8 +50,8 @@ func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.F } // DecRef implements refs.RefCounter.DecRef. -func (m *SpecialMappable) DecRef() { - m.AtomicRefCount.DecRefWithDestructor(func() { +func (m *SpecialMappable) DecRef(ctx context.Context) { + m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { m.mfp.MemoryFile().DecRef(m.fr) }) } diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 3f496aa9f..e74d4e1c1 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -101,7 +101,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme if err != nil { return 0, err } - defer m.DecRef() + defer m.DecRef(ctx) opts.MappingIdentity = m opts.Mappable = m } @@ -1191,7 +1191,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar)) mm.mappingMu.RUnlock() err := id.Msync(ctx, mr) - id.DecRef() + id.DecRef(ctx) if err != nil { return err } diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 16d8207e9..bd751d696 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -377,7 +377,7 @@ func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRa vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked()) } if vma.id != nil { - vma.id.DecRef() + vma.id.DecRef(ctx) } mm.usageAS -= uint64(vmaAR.Length()) if vma.isPrivateDataLocked() { @@ -446,7 +446,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa } if vma2.id != nil { - vma2.id.DecRef() + vma2.id.DecRef(context.Background()) } return vma1, true } diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 8b439a078..70ccf77a7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -68,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { for _, fd := range fds { file := t.GetFile(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -100,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFiles) Release() { +func (fs *RightsFiles) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -115,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index fd08179be..d9621968c 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -46,7 +46,7 @@ func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { for _, fd := range fds { file := t.GetFileVFS2(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -78,9 +78,9 @@ func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFilesVFS2) Release() { +func (fs *RightsFilesVFS2) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -93,7 +93,7 @@ func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 532a1ea5d..242e6bf76 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -100,12 +100,12 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc return nil, syserr.FromError(err) } dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } @@ -269,7 +269,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, @@ -281,7 +281,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, syscall.Close(fd) return 0, nil, 0, err } - defer f.DecRef() + defer f.DecRef(t) kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 0d45e5053..31e374833 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -97,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int } d := socket.NewDirent(t, netlinkSocketDevice) - defer d.DecRef() + defer d.DecRef(t) return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 98ca7add0..68a9b9a96 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -140,14 +140,14 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } @@ -164,9 +164,9 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { - s.connection.Release() - s.ep.Close() +func (s *socketOpsCommon) Release(ctx context.Context) { + s.connection.Release(ctx) + s.ep.Close(ctx) if s.bound { s.ports.Release(s.protocol.Protocol(), s.portID) @@ -621,7 +621,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -648,7 +648,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index dbcd8b49a..a38d25da9 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -57,14 +57,14 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 31a168f7e..e4846bc0b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -330,7 +330,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue } dirent := socket.NewDirent(t, netstackDevice) - defer dirent.DecRef() + defer dirent.DecRef(t) return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ socketOpsCommon: socketOpsCommon{ Queue: queue, @@ -479,7 +479,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(context.Context) { s.Endpoint.Close() } @@ -854,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index a9025b0ec..3335e7430 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -169,7 +169,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil { return 0, nil, 0, syserr.FromError(err) diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index d112757fb..04b259d27 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -46,8 +46,8 @@ type ControlMessages struct { } // Release releases Unix domain socket credentials and rights. -func (c *ControlMessages) Release() { - c.Unix.Release() +func (c *ControlMessages) Release(ctx context.Context) { + c.Unix.Release(ctx) } // Socket is an interface combining fs.FileOperations and SocketOps, diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index a1e49cc57..c67b602f0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -211,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool { // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. -func (e *connectionedEndpoint) Close() { +func (e *connectionedEndpoint) Close(ctx context.Context) { e.Lock() var c ConnectedEndpoint var r Receiver @@ -233,7 +233,7 @@ func (e *connectionedEndpoint) Close() { case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { - n.Close() + n.Close(ctx) } e.acceptedChan = nil e.path = "" @@ -241,11 +241,11 @@ func (e *connectionedEndpoint) Close() { e.Unlock() if c != nil { c.CloseNotify() - c.Release() + c.Release(ctx) } if r != nil { r.CloseNotify() - r.Release() + r.Release(ctx) } } @@ -340,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: // Busy; return ECONNREFUSED per spec. - ne.Close() + ne.Close(ctx) e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4b06d63ac..70ee8f9b8 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -func (e *connectionlessEndpoint) Close() { +func (e *connectionlessEndpoint) Close(ctx context.Context) { e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) e.connected = nil } @@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() { e.Unlock() r.CloseNotify() - r.Release() + r.Release(ctx) } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C if err != nil { return 0, syserr.ErrInvalidEndpointState } - defer connected.Release() + defer connected.Release(ctx) e.Lock() - n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) } e.connected = connected e.Unlock() diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index d8f3ad63d..ef6043e19 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" @@ -57,10 +58,10 @@ func (q *queue) Close() { // Both the read and write queues must be notified after resetting: // q.ReaderQueue.Notify(waiter.EventIn) // q.WriterQueue.Notify(waiter.EventOut) -func (q *queue) Reset() { +func (q *queue) Reset(ctx context.Context) { q.mu.Lock() for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release() + cur.Release(ctx) } q.dataList.Reset() q.used = 0 @@ -68,8 +69,8 @@ func (q *queue) Reset() { } // DecRef implements RefCounter.DecRef with destructor q.Reset. -func (q *queue) DecRef() { - q.DecRefWithDestructor(q.Reset) +func (q *queue) DecRef(ctx context.Context) { + q.DecRefWithDestructor(ctx, q.Reset) // We don't need to notify after resetting because no one cares about // this queue after all references have been dropped. } @@ -111,7 +112,7 @@ func (q *queue) IsWritable() bool { // // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { +func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { q.mu.Lock() if q.closed { @@ -124,7 +125,7 @@ func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress } if discardEmpty && l == 0 { q.mu.Unlock() - c.Release() + c.Release(ctx) return 0, false, nil } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2f1b127df..475d7177e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -37,7 +37,7 @@ type RightsControlMessage interface { Clone() RightsControlMessage // Release releases any resources owned by the RightsControlMessage. - Release() + Release(ctx context.Context) } // A CredentialsControlMessage is a control message containing Unix credentials. @@ -74,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages { } // Release releases both the credentials and the rights. -func (c *ControlMessages) Release() { +func (c *ControlMessages) Release(ctx context.Context) { if c.Rights != nil { - c.Rights.Release() + c.Rights.Release(ctx) } *c = ControlMessages{} } @@ -90,7 +90,7 @@ type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources // associated with it. - Close() + Close(ctx context.Context) // RecvMsg reads data and a control message from the endpoint. This method // does not block if there is no data pending. @@ -252,7 +252,7 @@ type BoundEndpoint interface { // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. - Release() + Release(ctx context.Context) } // message represents a message passed over a Unix domain socket. @@ -281,8 +281,8 @@ func (m *message) Length() int64 { } // Release releases any resources held by the message. -func (m *message) Release() { - m.Control.Release() +func (m *message) Release(ctx context.Context) { + m.Control.Release(ctx) } // Peek returns a copy of the message. @@ -304,7 +304,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -333,7 +333,7 @@ type Receiver interface { // Release releases any resources owned by the Receiver. It should be // called before droping all references to a Receiver. - Release() + Release(ctx context.Context) } // queueReceiver implements Receiver for datagram sockets. @@ -344,7 +344,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -398,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (q *queueReceiver) Release() { - q.readQueue.DecRef() +func (q *queueReceiver) Release(ctx context.Context) { + q.readQueue.DecRef(ctx) } // streamQueueReceiver implements Receiver for stream sockets. @@ -456,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -502,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, var cmTruncated bool if c.Rights != nil && numRights == 0 { - c.Rights.Release() + c.Rights.Release(ctx) c.Rights = nil cmTruncated = true } @@ -557,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, // Consume rights. if numRights == 0 { cmTruncated = true - q.control.Rights.Release() + q.control.Rights.Release(ctx) } else { c.Rights = q.control.Rights haveRights = true @@ -582,7 +582,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) + Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -616,7 +616,7 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. - Release() + Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. @@ -654,7 +654,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { +func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { discardEmpty := false truncate := false if e.endpoint.Type() == linux.SOCK_STREAM { @@ -669,7 +669,7 @@ func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.Fu truncate = true } - return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate) + return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate) } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -707,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (e *connectedEndpoint) Release() { - e.writeQueue.DecRef() +func (e *connectedEndpoint) Release(ctx context.Context) { + e.writeQueue.DecRef(ctx) } // CloseUnread implements ConnectedEndpoint.CloseUnread. @@ -798,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek) e.Unlock() if err != nil { return 0, 0, ControlMessages{}, false, err @@ -827,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return 0, syserr.ErrAlreadyConnected } - n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -1001,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() { +func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0482d33cf..2b8454edb 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -62,7 +62,7 @@ type SocketOperations struct { // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -97,17 +97,17 @@ type socketOpsCommon struct { } // DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() +func (s *socketOpsCommon) DecRef(ctx context.Context) { + s.DecRefWithDestructor(ctx, func(context.Context) { + s.ep.Close(ctx) }) } // Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. - s.DecRef() + s.DecRef(ctx) } func (s *socketOpsCommon) isPacket() bool { @@ -234,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -284,7 +284,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } @@ -294,7 +294,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -302,7 +302,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -318,7 +318,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } @@ -332,7 +332,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -378,9 +378,9 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, FollowFinalSymlink: true, } ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) - root.DecRef() + root.DecRef(t) if relPath { - start.DecRef() + start.DecRef(t) } if e != nil { return nil, syserr.FromError(e) @@ -393,15 +393,15 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused @@ -415,7 +415,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. err = s.ep.Connect(t, ep) @@ -473,7 +473,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if err != nil { return 0, err } - defer ep.Release() + defer ep.Release(t) w.To = ep if ep.Passcred() && w.Control.Credentials == nil { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 05c16fcfe..dfa25241a 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -136,7 +136,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) @@ -183,19 +183,19 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } } else { path := fspath.Parse(p) root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root relPath := !path.Absolute if relPath { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } pop := vfs.PathOperation{ Root: root, @@ -333,7 +333,7 @@ func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) f, err := NewSockfsFile(t, ep, stype) if err != nil { - ep.Close() + ep.Close(t) return nil, err } return f, nil @@ -357,14 +357,14 @@ func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (* ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) s1, err := NewSockfsFile(t, ep1, stype) if err != nil { - ep1.Close() - ep2.Close() + ep1.Close(t) + ep2.Close(t) return nil, nil, err } s2, err := NewSockfsFile(t, ep2, stype) if err != nil { - s1.DecRef() - ep2.Close() + s1.DecRef(t) + ep2.Close(t) return nil, nil, err } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 68ca537c8..87b239730 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -147,14 +147,14 @@ func fd(t *kernel.Task, fd int32) string { root := t.FSContext().RootDirectory() if root != nil { - defer root.DecRef() + defer root.DecRef(t) } if fd == linux.AT_FDCWD { wd := t.FSContext().WorkingDirectory() var name string if wd != nil { - defer wd.DecRef() + defer wd.DecRef(t) name, _ = wd.FullName(root) } else { name = "(unknown cwd)" @@ -167,7 +167,7 @@ func fd(t *kernel.Task, fd int32) string { // Cast FD to uint64 to avoid printing negative hex. return fmt.Sprintf("%#x (bad FD)", uint64(fd)) } - defer file.DecRef() + defer file.DecRef(t) name, _ := file.Dirent.FullName(root) return fmt.Sprintf("%#x %s", fd, name) @@ -175,12 +175,12 @@ func fd(t *kernel.Task, fd int32) string { func fdVFS2(t *kernel.Task, fd int32) string { root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) vfsObj := root.Mount().Filesystem().VirtualFilesystem() if fd == linux.AT_FDCWD { wd := t.FSContext().WorkingDirectoryVFS2() - defer wd.DecRef() + defer wd.DecRef(t) name, _ := vfsObj.PathnameWithDeleted(t, root, wd) return fmt.Sprintf("AT_FDCWD %s", name) @@ -191,7 +191,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { // Cast FD to uint64 to avoid printing negative hex. return fmt.Sprintf("%#x (bad FD)", uint64(fd)) } - defer file.DecRef() + defer file.DecRef(t) name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry()) return fmt.Sprintf("%#x %s", fd, name) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index d9fb808c0..d23a0068a 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -28,7 +28,7 @@ import ( // CreateEpoll implements the epoll_create(2) linux syscall. func CreateEpoll(t *kernel.Task, closeOnExec bool) (int32, error) { file := epoll.NewEventPoll(t) - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: closeOnExec, @@ -47,14 +47,14 @@ func AddEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -73,14 +73,14 @@ func UpdateEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, m if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -99,14 +99,14 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { if epollfile == nil { return syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Get the target file id. file := t.GetFile(fd) if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) @@ -115,7 +115,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // Try to remove the entry. - return e.RemoveEntry(epoll.FileIdentifier{file, fd}) + return e.RemoveEntry(t, epoll.FileIdentifier{file, fd}) } // WaitEpoll implements the epoll_wait(2) linux syscall. @@ -125,7 +125,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve if epollfile == nil { return nil, syserror.EBADF } - defer epollfile.DecRef() + defer epollfile.DecRef(t) // Extract the epollPoll operations. e, ok := epollfile.FileOperations.(*epoll.EventPoll) diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index ba2557c52..e9d64dec5 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -247,7 +247,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - file.DecRef() + file.DecRef(ctx) // Queue the result for delivery. actx.FinishRequest(ev) @@ -257,7 +257,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu // wake up. if eventFile != nil { eventFile.FileOperations.(*eventfd.EventOperations).Signal(1) - eventFile.DecRef() + eventFile.DecRef(ctx) } } } @@ -269,7 +269,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user // File not found. return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Was there an eventFD? Extract it. var eventFile *fs.File @@ -279,7 +279,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user // Bad FD. return syserror.EBADF } - defer eventFile.DecRef() + defer eventFile.DecRef(t) // Check that it is an eventfd. if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok { diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go index ed3413ca6..3b4f879e4 100644 --- a/pkg/sentry/syscalls/linux/sys_eventfd.go +++ b/pkg/sentry/syscalls/linux/sys_eventfd.go @@ -37,7 +37,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc event.SetFlags(fs.SettableFileFlags{ NonBlocking: flags&linux.EFD_NONBLOCK != 0, }) - defer event.DecRef() + defer event.DecRef(t) fd, err := t.NewFDFrom(0, event, kernel.FDFlags{ CloseOnExec: flags&linux.EFD_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 8cf6401e7..1bc9b184e 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -40,7 +40,7 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent, // Common case: we are accessing a file in the root. root := t.FSContext().RootDirectory() err := fn(root, root, name, linux.MaxSymlinkTraversals) - root.DecRef() + root.DecRef(t) return err } else if dir == "." && dirFD == linux.AT_FDCWD { // Common case: we are accessing a file relative to the current @@ -48,8 +48,8 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent, wd := t.FSContext().WorkingDirectory() root := t.FSContext().RootDirectory() err := fn(root, wd, name, linux.MaxSymlinkTraversals) - wd.DecRef() - root.DecRef() + wd.DecRef(t) + root.DecRef(t) return err } @@ -97,19 +97,19 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro } else { d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals) } - root.DecRef() + root.DecRef(t) if wd != nil { - wd.DecRef() + wd.DecRef(t) } if f != nil { - f.DecRef() + f.DecRef(t) } if err != nil { return err } err = fn(root, d, remainingTraversals) - d.DecRef() + d.DecRef(t) return err } @@ -186,7 +186,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint if err != nil { return syserror.ConvertIntr(err, kernel.ERESTARTSYS) } - defer file.DecRef() + defer file.DecRef(t) // Success. newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ @@ -242,7 +242,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode if err != nil { return err } - file.DecRef() + file.DecRef(t) return nil case linux.ModeNamedPipe: @@ -332,7 +332,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l if err != nil { break } - defer found.DecRef() + defer found.DecRef(t) // We found something (possibly a symlink). If the // O_EXCL flag was passed, then we can immediately @@ -357,7 +357,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l resolved, err = found.Inode.Getlink(t) if err == nil { // No more resolution necessary. - defer resolved.DecRef() + defer resolved.DecRef(t) break } if err != fs.ErrResolveViaReadlink { @@ -384,7 +384,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l if err != nil { break } - defer newParent.DecRef() + defer newParent.DecRef(t) // Repeat the process with the parent and name of the // symlink target. @@ -416,7 +416,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l if err != nil { return syserror.ConvertIntr(err, kernel.ERESTARTSYS) } - defer newFile.DecRef() + defer newFile.DecRef(t) case syserror.ENOENT: // File does not exist. Proceed with creation. @@ -432,7 +432,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l // No luck, bail. return err } - defer newFile.DecRef() + defer newFile.DecRef(t) found = newFile.Dirent default: return err @@ -596,7 +596,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Shared flags between file and socket. switch request { @@ -671,9 +671,9 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal addr := args[0].Pointer() size := args[1].SizeT() cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Get our fullname from the root and preprend unreachable if the root was // unreachable from our current dirent this is the same behavior as on linux. @@ -722,7 +722,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return err } - t.FSContext().SetRootDirectory(d) + t.FSContext().SetRootDirectory(t, d) return nil }) } @@ -747,7 +747,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return err } - t.FSContext().SetWorkingDirectory(d) + t.FSContext().SetWorkingDirectory(t, d) return nil }) } @@ -760,7 +760,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is it a directory? if !fs.IsDir(file.Dirent.Inode.StableAttr) { @@ -772,7 +772,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } - t.FSContext().SetWorkingDirectory(file.Dirent) + t.FSContext().SetWorkingDirectory(t, file.Dirent) return 0, nil, nil } @@ -791,7 +791,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Flush(t) return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file) @@ -805,7 +805,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{}) if err != nil { @@ -826,7 +826,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if oldFile == nil { return 0, nil, syserror.EBADF } - defer oldFile.DecRef() + defer oldFile.DecRef(t) return uintptr(newfd), nil, nil } @@ -850,7 +850,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if oldFile == nil { return 0, nil, syserror.EBADF } - defer oldFile.DecRef() + defer oldFile.DecRef(t) err := t.NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0}) if err != nil { @@ -925,7 +925,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: @@ -1132,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // If the FD refers to a pipe or FIFO, return error. if fs.IsPipe(file.Dirent.Inode.StableAttr) { @@ -1171,7 +1171,7 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode switch err { case nil: // The directory existed. - defer f.DecRef() + defer f.DecRef(t) return syserror.EEXIST case syserror.EACCES: // Permission denied while walking to the directory. @@ -1349,7 +1349,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32 if target == nil { return syserror.EBADF } - defer target.DecRef() + defer target.DecRef(t) if err := mayLinkAt(t, target.Dirent.Inode); err != nil { return err } @@ -1602,7 +1602,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Reject truncation if the file flags do not permit this operation. // This is different from truncate(2) above. @@ -1730,7 +1730,7 @@ func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bo if file == nil { return syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return chown(t, file.Dirent, uid, gid) } @@ -1768,7 +1768,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, chown(t, file.Dirent, uid, gid) } @@ -1833,7 +1833,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, chmod(t, file.Dirent, mode) } @@ -1893,10 +1893,10 @@ func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, reso if f == nil { return syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) return setTimestamp(root, f.Dirent, linux.MaxSymlinkTraversals) } @@ -2088,7 +2088,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if offset < 0 || length <= 0 { return 0, nil, syserror.EINVAL @@ -2141,7 +2141,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // flock(2): EBADF fd is not an open file descriptor. return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) nonblocking := operation&linux.LOCK_NB != 0 operation &^= linux.LOCK_NB @@ -2224,8 +2224,8 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } - defer dirent.DecRef() - defer file.DecRef() + defer dirent.DecRef(t) + defer file.DecRef(t) newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: cloExec, diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index f04d78856..9d1b2edb1 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -73,7 +73,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts)) } - t.Futex().WaitComplete(w) + t.Futex().WaitComplete(w, t) return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS) } @@ -95,7 +95,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add } remaining, err := t.BlockWithTimeout(w.C, !forever, duration) - t.Futex().WaitComplete(w) + t.Futex().WaitComplete(w, t) if err == nil { return 0, nil } @@ -148,7 +148,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A timer.Destroy() } - t.Futex().WaitComplete(w) + t.Futex().WaitComplete(w, t) return syserror.ConvertIntr(err, kernel.ERESTARTSYS) } diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index b126fecc0..f5699e55d 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -68,7 +68,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir if dir == nil { return 0, syserror.EBADF } - defer dir.DecRef() + defer dir.DecRef(t) w := &usermem.IOReadWriter{ Ctx: t, diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go index b2c7b3444..cf47bb9dd 100644 --- a/pkg/sentry/syscalls/linux/sys_inotify.go +++ b/pkg/sentry/syscalls/linux/sys_inotify.go @@ -40,7 +40,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. NonBlocking: flags&linux.IN_NONBLOCK != 0, } n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t)) - defer n.DecRef() + defer n.DecRef(t) fd, err := t.NewFDFrom(0, n, kernel.FDFlags{ CloseOnExec: flags&linux.IN_CLOEXEC != 0, @@ -71,7 +71,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) { ino, ok := file.FileOperations.(*fs.Inotify) if !ok { // Not an inotify fd. - file.DecRef() + file.DecRef(t) return nil, nil, syserror.EINVAL } @@ -98,7 +98,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -128,6 +128,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if err != nil { return 0, nil, err } - defer file.DecRef() - return 0, nil, ino.RmWatch(wd) + defer file.DecRef(t) + return 0, nil, ino.RmWatch(t, wd) } diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 3f7691eae..1c38f8f4f 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -33,7 +33,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) var sw fs.SeekWhence switch whence { diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 91694d374..72786b032 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -75,7 +75,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } defer func() { if opts.MappingIdentity != nil { - opts.MappingIdentity.DecRef() + opts.MappingIdentity.DecRef(t) } }() @@ -85,7 +85,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) flags := file.Flags() // mmap unconditionally requires that the FD is readable. diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index eb5ff48f5..bd0633564 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -115,7 +115,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }); err != nil { // Something went wrong. Drop our ref on rootInode before // returning the error. - rootInode.DecRef() + rootInode.DecRef(t) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 43c510930..3149e4aad 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -34,10 +34,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) r.SetFlags(linuxToFlags(flags).Settable()) - defer r.DecRef() + defer r.DecRef(t) w.SetFlags(linuxToFlags(flags).Settable()) - defer w.DecRef() + defer w.DecRef(t) fds, err := t.NewFDs(0, []*fs.File{r, w}, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -49,7 +49,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { if _, err := t.CopyOut(addr, fds); err != nil { for _, fd := range fds { if file, _ := t.FDTable().Remove(fd); file != nil { - file.DecRef() + file.DecRef(t) } } return 0, err diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index f0198141c..3435bdf77 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -70,7 +70,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } if ch == nil { - defer file.DecRef() + defer file.DecRef(t) } else { state.file = file state.waiter, _ = waiter.NewChannelEntry(ch) @@ -82,11 +82,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } // releaseState releases all the pollState in "state". -func releaseState(state []pollState) { +func releaseState(t *kernel.Task, state []pollState) { for i := range state { if state[i].file != nil { state[i].file.EventUnregister(&state[i].waiter) - state[i].file.DecRef() + state[i].file.DecRef(t) } } } @@ -107,7 +107,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // result, we stop registering for events but still go through all files // to get their ready masks. state := make([]pollState, len(pfd)) - defer releaseState(state) + defer releaseState(t, state) n := uintptr(0) for i := range pfd { initReadiness(t, &pfd[i], &state[i], ch) @@ -266,7 +266,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add if file == nil { return 0, syserror.EBADF } - file.DecRef() + file.DecRef(t) var mask int16 if (rV & m) != 0 { diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index f92bf8096..64a725296 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -128,7 +128,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // They trying to set exe to a non-file? if !fs.IsFile(file.Dirent.Inode.StableAttr) { @@ -136,7 +136,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } // Set the underlying executable. - t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file)) + t.MemoryManager().SetExecutable(t, fsbridge.NewFSFile(file)) case linux.PR_SET_MM_AUXV, linux.PR_SET_MM_START_CODE, diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 071b4bacc..3bbc3fa4b 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -48,7 +48,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -84,7 +84,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -118,7 +118,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -164,7 +164,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.Flags().Read { @@ -195,7 +195,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -244,7 +244,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index 4a8bc24a2..f0ae8fa8e 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -39,7 +39,7 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer segment.DecRef() + defer segment.DecRef(t) return uintptr(segment.ID), nil, nil } @@ -66,7 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC, @@ -108,7 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) stat, err := segment.IPCStat(t) if err == nil { @@ -132,7 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } - defer segment.DecRef() + defer segment.DecRef(t) switch cmd { case linux.IPC_SET: @@ -145,7 +145,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err case linux.IPC_RMID: - segment.MarkDestroyed() + segment.MarkDestroyed(t) return 0, nil, nil case linux.SHM_LOCK, linux.SHM_UNLOCK: diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index d2b0012ae..20cb1a5cb 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -536,7 +536,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is this a signalfd? if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { @@ -553,7 +553,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) // Set appropriate flags. file.SetFlags(fs.SettableFileFlags{ diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 414fce8e3..fec1c1974 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -200,7 +200,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal s.SetFlags(fs.SettableFileFlags{ NonBlocking: stype&linux.SOCK_NONBLOCK != 0, }) - defer s.DecRef() + defer s.DecRef(t) fd, err := t.NewFDFrom(0, s, kernel.FDFlags{ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, @@ -235,8 +235,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } s1.SetFlags(fileFlags) s2.SetFlags(fileFlags) - defer s1.DecRef() - defer s2.DecRef() + defer s1.DecRef(t) + defer s2.DecRef(t) // Create the FDs for the sockets. fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{ @@ -250,7 +250,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if _, err := t.CopyOut(socks, fds); err != nil { for _, fd := range fds { if file, _ := t.FDTable().Remove(fd); file != nil { - file.DecRef() + file.DecRef(t) } } return 0, nil, err @@ -270,7 +270,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -301,7 +301,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -360,7 +360,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -387,7 +387,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -416,7 +416,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -447,7 +447,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -529,7 +529,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -567,7 +567,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -595,7 +595,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -628,7 +628,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -681,7 +681,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -775,7 +775,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Release() + cms.Release(t) } if int(msg.Flags) != mflags { @@ -795,7 +795,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Release() + defer cms.Release(t) controlData := make([]byte, 0, msg.ControlLen) controlData = control.PackControlMessages(t, cms, controlData) @@ -851,7 +851,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -880,7 +880,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Release() + cm.Release(t) if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -924,7 +924,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -962,7 +962,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) @@ -1066,7 +1066,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { - controlMessages.Release() + controlMessages.Release(t) } return uintptr(n), err } @@ -1084,7 +1084,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.FileOperations.(socket.Socket) diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 77c78889d..b8846a10a 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -101,7 +101,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) if !inFile.Flags().Read { return 0, nil, syserror.EBADF @@ -111,7 +111,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) if !outFile.Flags().Write { return 0, nil, syserror.EBADF @@ -192,13 +192,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) inFile := t.GetFile(inFD) if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) // The operation is non-blocking if anything is non-blocking. // @@ -300,13 +300,13 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) inFile := t.GetFile(inFD) if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) // All files must be pipes. if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) { diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 46ebf27a2..a5826f2dd 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -58,7 +58,7 @@ func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, fstat(t, file, statAddr) } @@ -100,7 +100,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, fstat(t, file, statAddr) } @@ -158,7 +158,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) uattr, err := file.UnstableAttr(t) if err != nil { return 0, nil, err @@ -249,7 +249,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, statfsImpl(t, file.Dirent, statfsAddr) } diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 5ad465ae3..f2c0e5069 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -39,7 +39,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Use "sync-the-world" for now, it's guaranteed that fd is at least // on the root filesystem. @@ -54,7 +54,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll) return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) @@ -70,7 +70,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData) return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) @@ -103,7 +103,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the // specified range that have already been submitted to the device diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 00915fdde..2d16e4933 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -117,7 +117,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) var wd *fs.Dirent var executable fsbridge.File @@ -133,7 +133,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) closeOnExec = fdFlags.CloseOnExec if atEmptyPath && len(pathname) == 0 { @@ -155,7 +155,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user } } if wd != nil { - defer wd.DecRef() + defer wd.DecRef(t) } // Load the new TaskContext. diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go index cf49b43db..34b03e4ee 100644 --- a/pkg/sentry/syscalls/linux/sys_timerfd.go +++ b/pkg/sentry/syscalls/linux/sys_timerfd.go @@ -43,7 +43,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.EINVAL } f := timerfd.NewFile(t, c) - defer f.DecRef() + defer f.DecRef(t) f.SetFlags(fs.SettableFileFlags{ NonBlocking: flags&linux.TFD_NONBLOCK != 0, }) @@ -73,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { @@ -107,7 +107,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 6ec0de96e..485526e28 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -48,7 +48,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is writable. if !file.Flags().Write { @@ -85,7 +85,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -131,7 +131,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is writable. if !file.Flags().Write { @@ -162,7 +162,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -215,7 +215,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index c24946160..97474fd3c 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -49,7 +49,7 @@ func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size) if err != nil { @@ -153,7 +153,7 @@ func FSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) return 0, nil, setXattr(t, f.Dirent, nameAddr, valueAddr, uint64(size), flags) } @@ -270,7 +270,7 @@ func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) n, err := listXattr(t, f.Dirent, listAddr, size) if err != nil { @@ -384,7 +384,7 @@ func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if f == nil { return 0, nil, syserror.EBADF } - defer f.DecRef() + defer f.DecRef(t) return 0, nil, removeXattr(t, f.Dirent, nameAddr) } diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index e5cdefc50..399b4f60c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -88,7 +88,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if fd == nil { return syserror.EBADF } - defer fd.DecRef() + defer fd.DecRef(t) // Was there an eventFD? Extract it. var eventFD *vfs.FileDescription @@ -97,7 +97,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if eventFD == nil { return syserror.EBADF } - defer eventFD.DecRef() + defer eventFD.DecRef(t) // Check that it is an eventfd. if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok { @@ -169,7 +169,7 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - fd.DecRef() + fd.DecRef(ctx) // Queue the result for delivery. aioCtx.FinishRequest(ev) @@ -179,7 +179,7 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use // wake up. if eventFD != nil { eventFD.Impl().(*eventfd.EventFileDescription).Signal(1) - eventFD.DecRef() + eventFD.DecRef(ctx) } } } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index 34c90ae3e..c62f03509 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -37,11 +37,11 @@ func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, syserror.EINVAL } - file, err := t.Kernel().VFS().NewEpollInstanceFD() + file, err := t.Kernel().VFS().NewEpollInstanceFD(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0, @@ -62,11 +62,11 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - file, err := t.Kernel().VFS().NewEpollInstanceFD() + file, err := t.Kernel().VFS().NewEpollInstanceFD(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) if err != nil { @@ -86,7 +86,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if epfile == nil { return 0, nil, syserror.EBADF } - defer epfile.DecRef() + defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { return 0, nil, syserror.EINVAL @@ -95,7 +95,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if epfile == file { return 0, nil, syserror.EINVAL } @@ -135,7 +135,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if epfile == nil { return 0, nil, syserror.EBADF } - defer epfile.DecRef() + defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go index aff1a2070..807f909da 100644 --- a/pkg/sentry/syscalls/linux/vfs2/eventfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go @@ -38,11 +38,11 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc fileFlags |= linux.O_NONBLOCK } semMode := flags&linux.EFD_SEMAPHORE != 0 - eventfd, err := eventfd.New(vfsObj, initVal, semMode, fileFlags) + eventfd, err := eventfd.New(t, vfsObj, initVal, semMode, fileFlags) if err != nil { return 0, nil, err } - defer eventfd.DecRef() + defer eventfd.DecRef(t) fd, err := t.NewFDFromVFS2(0, eventfd, kernel.FDFlags{ CloseOnExec: flags&linux.EFD_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index aef0078a8..066ee0863 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -71,7 +71,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) var executable fsbridge.File closeOnExec := false if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute { @@ -90,7 +90,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user } start := dirfile.VirtualDentry() start.IncRef() - dirfile.DecRef() + dirfile.DecRef(t) closeOnExec = dirfileFlags.CloseOnExec file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{ Root: root, @@ -101,19 +101,19 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user Flags: linux.O_RDONLY, FileExec: true, }) - start.DecRef() + start.DecRef(t) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) executable = fsbridge.NewVFSFile(file) } // Load the new TaskContext. mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change - defer mntns.DecRef() + defer mntns.DecRef(t) wd := t.FSContext().WorkingDirectoryVFS2() - defer wd.DecRef() + defer wd.DecRef(t) remainingTraversals := uint(linux.MaxSymlinkTraversals) loadArgs := loader.LoadArgs{ Opener: fsbridge.NewVFSLookup(mntns, root, wd), diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 67f191551..72ca916a0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -38,7 +38,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := file.OnClose(t) return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file) @@ -52,7 +52,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{}) if err != nil { @@ -72,7 +72,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - file.DecRef() + file.DecRef(t) return uintptr(newfd), nil, nil } @@ -101,7 +101,7 @@ func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -121,7 +121,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: @@ -332,7 +332,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // If the FD refers to a pipe or FIFO, return error. if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index b6d2ddd65..01e0f9010 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -56,7 +56,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i if err != nil { return err } - defer oldtpop.Release() + defer oldtpop.Release(t) newpath, err := copyInPath(t, newpathAddr) if err != nil { @@ -66,7 +66,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd i if err != nil { return err } - defer newtpop.Release() + defer newtpop.Release(t) return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop) } @@ -95,7 +95,7 @@ func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()), }) @@ -127,7 +127,7 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) // "Zero file type is equivalent to type S_IFREG." - mknod(2) if mode.FileType() == 0 { @@ -174,7 +174,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{ Flags: flags | linux.O_LARGEFILE, @@ -183,7 +183,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -227,7 +227,7 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd if err != nil { return err } - defer oldtpop.Release() + defer oldtpop.Release(t) newpath, err := copyInPath(t, newpathAddr) if err != nil { @@ -237,7 +237,7 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd if err != nil { return err } - defer newtpop.Release() + defer newtpop.Release(t) return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{ Flags: flags, @@ -259,7 +259,7 @@ func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop) } @@ -278,7 +278,7 @@ func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error { if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop) } @@ -329,6 +329,6 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target) } diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go index 317409a18..a7d4d2a36 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fscontext.go +++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go @@ -31,8 +31,8 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal root := t.FSContext().RootDirectoryVFS2() wd := t.FSContext().WorkingDirectoryVFS2() s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd) - root.DecRef() - wd.DecRef() + root.DecRef(t) + wd.DecRef(t) if err != nil { return 0, nil, err } @@ -67,7 +67,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -75,8 +75,8 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - t.FSContext().SetWorkingDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetWorkingDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } @@ -88,7 +88,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -96,8 +96,8 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - t.FSContext().SetWorkingDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetWorkingDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } @@ -117,7 +117,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{ CheckSearchable: true, @@ -125,7 +125,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - t.FSContext().SetRootDirectoryVFS2(vd) - vd.DecRef() + t.FSContext().SetRootDirectoryVFS2(t, vd) + vd.DecRef(t) return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index c7c7bf7ce..5517595b5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -44,7 +44,7 @@ func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) cb := getGetdentsCallback(t, addr, size, isGetdents64) err := file.IterDirents(t, cb) diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go index 5d98134a5..11753d8e5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/inotify.go +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -35,7 +35,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if err != nil { return 0, nil, err } - defer ino.DecRef() + defer ino.DecRef(t) fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{ CloseOnExec: flags&linux.IN_CLOEXEC != 0, @@ -66,7 +66,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, ino, ok := f.Impl().(*vfs.Inotify) if !ok { // Not an inotify fd. - f.DecRef() + f.DecRef(t) return nil, nil, syserror.EINVAL } @@ -96,7 +96,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer f.DecRef() + defer f.DecRef(t) path, err := copyInPath(t, addr) if err != nil { @@ -109,12 +109,12 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{}) if err != nil { return 0, nil, err } - defer d.DecRef() + defer d.DecRef(t) fd, err = ino.AddWatch(d.Dentry(), mask) if err != nil { @@ -132,6 +132,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if err != nil { return 0, nil, err } - defer f.DecRef() - return 0, nil, ino.RmWatch(wd) + defer f.DecRef(t) + return 0, nil, ino.RmWatch(t, wd) } diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index fd6ab94b2..38778a388 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -29,7 +29,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Handle ioctls that apply to all FDs. switch args[1].Int() { diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go index bf19028c4..b910b5a74 100644 --- a/pkg/sentry/syscalls/linux/vfs2/lock.go +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -32,7 +32,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // flock(2): EBADF fd is not an open file descriptor. return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) nonblocking := operation&linux.LOCK_NB != 0 operation &^= linux.LOCK_NB diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go index bbe248d17..519583e4e 100644 --- a/pkg/sentry/syscalls/linux/vfs2/memfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go @@ -47,7 +47,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } shmMount := t.Kernel().ShmMount() - file, err := tmpfs.NewMemfd(shmMount, t.Credentials(), allowSeals, memfdPrefix+name) + file, err := tmpfs.NewMemfd(t, t.Credentials(), shmMount, allowSeals, memfdPrefix+name) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index 60a43f0a0..dc05c2994 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -61,7 +61,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } defer func() { if opts.MappingIdentity != nil { - opts.MappingIdentity.DecRef() + opts.MappingIdentity.DecRef(t) } }() @@ -71,7 +71,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // mmap unconditionally requires that the FD is readable. if !file.IsReadable() { diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index ea337de7c..4bd5c7ca2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -108,7 +108,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer target.Release() + defer target.Release(t) return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) } @@ -140,7 +140,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) opts := vfs.UmountOptions{ Flags: uint32(flags), diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go index 97da6c647..90a511d9a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/path.go +++ b/pkg/sentry/syscalls/linux/vfs2/path.go @@ -42,7 +42,7 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA haveStartRef := false if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { - root.DecRef() + root.DecRef(t) return taskPathOperation{}, syserror.ENOENT } if dirfd == linux.AT_FDCWD { @@ -51,13 +51,13 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { - root.DecRef() + root.DecRef(t) return taskPathOperation{}, syserror.EBADF } start = dirfile.VirtualDentry() start.IncRef() haveStartRef = true - dirfile.DecRef() + dirfile.DecRef(t) } } return taskPathOperation{ @@ -71,10 +71,10 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA }, nil } -func (tpop *taskPathOperation) Release() { - tpop.pop.Root.DecRef() +func (tpop *taskPathOperation) Release(t *kernel.Task) { + tpop.pop.Root.DecRef(t) if tpop.haveStartRef { - tpop.pop.Start.DecRef() + tpop.pop.Start.DecRef(t) tpop.haveStartRef = false } } diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index 4a01e4209..9b4848d9e 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -42,8 +42,8 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { return syserror.EINVAL } r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) - defer r.DecRef() - defer w.DecRef() + defer r.DecRef(t) + defer w.DecRef(t) fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -54,7 +54,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if _, err := t.CopyOut(addr, fds); err != nil { for _, fd := range fds { if _, file := t.FDTable().Remove(fd); file != nil { - file.DecRef() + file.DecRef(t) } } return err diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index ff1b25d7b..7b9d5e18a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -73,7 +73,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } if ch == nil { - defer file.DecRef() + defer file.DecRef(t) } else { state.file = file state.waiter, _ = waiter.NewChannelEntry(ch) @@ -85,11 +85,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan } // releaseState releases all the pollState in "state". -func releaseState(state []pollState) { +func releaseState(t *kernel.Task, state []pollState) { for i := range state { if state[i].file != nil { state[i].file.EventUnregister(&state[i].waiter) - state[i].file.DecRef() + state[i].file.DecRef(t) } } } @@ -110,7 +110,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // result, we stop registering for events but still go through all files // to get their ready masks. state := make([]pollState, len(pfd)) - defer releaseState(state) + defer releaseState(t, state) n := uintptr(0) for i := range pfd { initReadiness(t, &pfd[i], &state[i], ch) @@ -269,7 +269,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add if file == nil { return 0, syserror.EBADF } - file.DecRef() + file.DecRef(t) var mask int16 if (rV & m) != 0 { diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index cd25597a7..a905dae0a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -44,7 +44,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the size is legitimate. si := int(size) @@ -75,7 +75,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Get the destination of the read. dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ @@ -94,7 +94,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -102,7 +102,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -135,7 +135,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -151,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -188,7 +188,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -226,7 +226,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -258,7 +258,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -266,7 +266,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -299,7 +299,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -314,7 +314,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the size is legitimate. si := int(size) @@ -345,7 +345,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Get the source of the write. src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{ @@ -364,7 +364,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -372,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -405,7 +405,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return total, err } @@ -421,7 +421,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { @@ -458,7 +458,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < 0 { @@ -496,7 +496,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the offset is legitimate. if offset < -1 { @@ -528,7 +528,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } return n, err } @@ -536,7 +536,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { if n > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return n, err } @@ -569,7 +569,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o file.EventUnregister(&w) if total > 0 { - file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) } return total, err } @@ -601,7 +601,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) newoff, err := file.Seek(t, offset, whence) return uintptr(newoff), nil, err @@ -617,7 +617,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Check that the file is readable. if !file.IsReadable() { diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 25cdb7a55..5e6eb13ba 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -66,7 +66,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ @@ -151,7 +151,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) var opts vfs.SetStatOptions if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil { @@ -197,7 +197,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if !file.IsWritable() { return 0, nil, syserror.EINVAL @@ -224,7 +224,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) if !file.IsWritable() { return 0, nil, syserror.EBADF @@ -258,7 +258,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) return 0, nil, nil } @@ -438,7 +438,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error { root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && !bool(shouldAllowEmptyPath) { @@ -446,7 +446,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -457,13 +457,13 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa // VirtualFilesystem.SetStatAt(), since the former may be able // to use opened file state to expedite the SetStat. err := dirfile.SetStat(t, *opts) - dirfile.DecRef() + dirfile.DecRef(t) return err } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{ diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go index 623992f6f..b89f34cdb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/signal.go +++ b/pkg/sentry/syscalls/linux/vfs2/signal.go @@ -45,7 +45,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Is this a signalfd? if sfd, ok := file.Impl().(*signalfd.SignalFileDescription); ok { @@ -68,7 +68,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) // Create a new descriptor. fd, err = t.NewFDFromVFS2(0, file, kernel.FDFlags{ diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 8096a8f9c..4a68c64f3 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -196,7 +196,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if e != nil { return 0, nil, e.ToError() } - defer s.DecRef() + defer s.DecRef(t) if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil { return 0, nil, err @@ -230,8 +230,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } // Adding to the FD table will cause an extra reference to be acquired. - defer s1.DecRef() - defer s2.DecRef() + defer s1.DecRef(t) + defer s2.DecRef(t) nonblocking := uint32(stype & linux.SOCK_NONBLOCK) if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { @@ -253,7 +253,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if _, err := t.CopyOut(addr, fds); err != nil { for _, fd := range fds { if _, file := t.FDTable().Remove(fd); file != nil { - file.DecRef() + file.DecRef(t) } } return 0, nil, err @@ -273,7 +273,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -304,7 +304,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -363,7 +363,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -390,7 +390,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -419,7 +419,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -450,7 +450,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -532,7 +532,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -570,7 +570,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -598,7 +598,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -631,7 +631,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -684,7 +684,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -778,7 +778,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Release() + cms.Release(t) } if int(msg.Flags) != mflags { @@ -798,7 +798,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Release() + defer cms.Release(t) controlData := make([]byte, 0, msg.ControlLen) controlData = control.PackControlMessages(t, cms, controlData) @@ -854,7 +854,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -883,7 +883,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Release() + cm.Release(t) if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -927,7 +927,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -965,7 +965,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) @@ -1069,7 +1069,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { - controlMessages.Release() + controlMessages.Release(t) } return uintptr(n), err } @@ -1087,7 +1087,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags if file == nil { return 0, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // Extract the socket. s, ok := file.Impl().(socket.SocketVFS2) diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 63ab11f8c..16f59fce9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -53,12 +53,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) outFile := t.GetFileVFS2(outFD) if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) // Check that both files support the required directionality. if !inFile.IsReadable() || !outFile.IsWritable() { @@ -175,7 +175,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // On Linux, inotify behavior is not very consistent with splice(2). We try // our best to emulate Linux for very basic calls to splice, where for some // reason, events are generated for output files, but not input files. - outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) return uintptr(n), nil, nil } @@ -203,12 +203,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) outFile := t.GetFileVFS2(outFD) if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) // Check that both files support the required directionality. if !inFile.IsReadable() || !outFile.IsWritable() { @@ -251,7 +251,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if n == 0 { return 0, nil, err } - outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) return uintptr(n), nil, nil } @@ -266,7 +266,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if inFile == nil { return 0, nil, syserror.EBADF } - defer inFile.DecRef() + defer inFile.DecRef(t) if !inFile.IsReadable() { return 0, nil, syserror.EBADF } @@ -275,7 +275,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if outFile == nil { return 0, nil, syserror.EBADF } - defer outFile.DecRef() + defer outFile.DecRef(t) if !outFile.IsWritable() { return 0, nil, syserror.EBADF } @@ -419,8 +419,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } - inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) - outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) return uintptr(n), nil, nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index bb1d5cac4..0f5d5189c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -65,7 +65,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { @@ -73,7 +73,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -85,7 +85,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags // former may be able to use opened file state to expedite the // Stat. statx, err := dirfile.Stat(t, opts) - dirfile.DecRef() + dirfile.DecRef(t) if err != nil { return err } @@ -96,8 +96,8 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } @@ -132,7 +132,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) statx, err := file.Stat(t, vfs.StatOptions{ Mask: linux.STATX_BASIC_STATS, @@ -177,7 +177,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } root := t.FSContext().RootDirectoryVFS2() - defer root.DecRef() + defer root.DecRef(t) start := root if !path.Absolute { if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 { @@ -185,7 +185,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if dirfd == linux.AT_FDCWD { start = t.FSContext().WorkingDirectoryVFS2() - defer start.DecRef() + defer start.DecRef(t) } else { dirfile := t.GetFileVFS2(dirfd) if dirfile == nil { @@ -197,7 +197,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // former may be able to use opened file state to expedite the // Stat. statx, err := dirfile.Stat(t, opts) - dirfile.DecRef() + dirfile.DecRef(t) if err != nil { return 0, nil, err } @@ -207,8 +207,8 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } start = dirfile.VirtualDentry() start.IncRef() - defer start.DecRef() - dirfile.DecRef() + defer start.DecRef(t) + dirfile.DecRef(t) } } @@ -282,7 +282,7 @@ func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) err if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) // access(2) and faccessat(2) check permissions using real // UID/GID, not effective UID/GID. @@ -328,7 +328,7 @@ func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, siz if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop) if err != nil { @@ -358,7 +358,7 @@ func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) if err != nil { @@ -377,7 +377,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 0d0ebf46a..a6491ac37 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -34,7 +34,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.SyncFS(t) } @@ -47,7 +47,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) return 0, nil, file.Sync(t) } @@ -77,7 +77,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support // is a full-file sync, i.e. fsync(2). As a result, there are severe diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go index 5ac79bc09..7a26890ef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go @@ -50,11 +50,11 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.EINVAL } vfsObj := t.Kernel().VFS() - file, err := timerfd.New(vfsObj, clock, fileFlags) + file, err := timerfd.New(t, vfsObj, clock, fileFlags) if err != nil { return 0, nil, err } - defer file.DecRef() + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.TFD_CLOEXEC != 0, }) @@ -79,7 +79,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { @@ -113,7 +113,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index af455d5c1..ef99246ed 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -49,7 +49,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { @@ -72,7 +72,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) names, err := file.Listxattr(t, uint64(size)) if err != nil { @@ -109,7 +109,7 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli if err != nil { return 0, nil, err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -141,7 +141,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -188,7 +188,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -222,7 +222,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -262,7 +262,7 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy if err != nil { return err } - defer tpop.Release() + defer tpop.Release(t) name, err := copyInXattrName(t, nameAddr) if err != nil { @@ -281,7 +281,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if file == nil { return 0, nil, syserror.EBADF } - defer file.DecRef() + defer file.DecRef(t) name, err := copyInXattrName(t, nameAddr) if err != nil { diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 641e3e502..5a0e3e6b5 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -82,7 +82,7 @@ type anonDentry struct { } // Release implements FilesystemImpl.Release. -func (fs *anonFilesystem) Release() { +func (fs *anonFilesystem) Release(ctx context.Context) { } // Sync implements FilesystemImpl.Sync. @@ -294,7 +294,7 @@ func (d *anonDentry) TryIncRef() bool { } // DecRef implements DentryImpl.DecRef. -func (d *anonDentry) DecRef() { +func (d *anonDentry) DecRef(ctx context.Context) { // no-op } @@ -303,7 +303,7 @@ func (d *anonDentry) DecRef() { // Although Linux technically supports inotify on pseudo filesystems (inotify // is implemented at the vfs layer), it is not particularly useful. It is left // unimplemented until someone actually needs it. -func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {} +func (d *anonDentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {} // Watches implements DentryImpl.Watches. func (d *anonDentry) Watches() *Watches { @@ -311,4 +311,4 @@ func (d *anonDentry) Watches() *Watches { } // OnZeroWatches implements Dentry.OnZeroWatches. -func (d *anonDentry) OnZeroWatches() {} +func (d *anonDentry) OnZeroWatches(context.Context) {} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index cea3e6955..bc7ea93ea 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -17,6 +17,7 @@ package vfs import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -102,7 +103,7 @@ type DentryImpl interface { TryIncRef() bool // DecRef decrements the Dentry's reference count. - DecRef() + DecRef(ctx context.Context) // InotifyWithParent notifies all watches on the targets represented by this // dentry and its parent. The parent's watches are notified first, followed @@ -113,7 +114,7 @@ type DentryImpl interface { // // Note that the events may not actually propagate up to the user, depending // on the event masks. - InotifyWithParent(events, cookie uint32, et EventType) + InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) // Watches returns the set of inotify watches for the file corresponding to // the Dentry. Dentries that are hard links to the same underlying file @@ -135,7 +136,7 @@ type DentryImpl interface { // The caller does not need to hold a reference on the dentry. OnZeroWatches // may acquire inotify locks, so to prevent deadlock, no inotify locks should // be held by the caller. - OnZeroWatches() + OnZeroWatches(ctx context.Context) } // IncRef increments d's reference count. @@ -150,8 +151,8 @@ func (d *Dentry) TryIncRef() bool { } // DecRef decrements d's reference count. -func (d *Dentry) DecRef() { - d.impl.DecRef() +func (d *Dentry) DecRef(ctx context.Context) { + d.impl.DecRef(ctx) } // IsDead returns true if d has been deleted or invalidated by its owning @@ -168,8 +169,8 @@ func (d *Dentry) isMounted() bool { // InotifyWithParent notifies all watches on the targets represented by d and // its parent of events. -func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) { - d.impl.InotifyWithParent(events, cookie, et) +func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) { + d.impl.InotifyWithParent(ctx, events, cookie, et) } // Watches returns the set of inotify watches associated with d. @@ -182,8 +183,8 @@ func (d *Dentry) Watches() *Watches { // OnZeroWatches performs cleanup tasks whenever the number of watches on a // dentry drops to zero. -func (d *Dentry) OnZeroWatches() { - d.impl.OnZeroWatches() +func (d *Dentry) OnZeroWatches(ctx context.Context) { + d.impl.OnZeroWatches(ctx) } // The following functions are exported so that filesystem implementations can @@ -214,11 +215,11 @@ func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { // CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion // succeeds. -func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { +func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) { d.dead = true d.mu.Unlock() if d.isMounted() { - vfs.forgetDeadMountpoint(d) + vfs.forgetDeadMountpoint(ctx, d) } } @@ -226,12 +227,12 @@ func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { // did for reasons outside of VFS' control (e.g. d represents the local state // of a file on a remote filesystem on which the file has already been // deleted). -func (vfs *VirtualFilesystem) InvalidateDentry(d *Dentry) { +func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) { d.mu.Lock() d.dead = true d.mu.Unlock() if d.isMounted() { - vfs.forgetDeadMountpoint(d) + vfs.forgetDeadMountpoint(ctx, d) } } @@ -278,13 +279,13 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { // that was replaced by from. // // Preconditions: PrepareRenameDentry was previously called on from and to. -func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, to *Dentry) { +func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) { from.mu.Unlock() if to != nil { to.dead = true to.mu.Unlock() if to.isMounted() { - vfs.forgetDeadMountpoint(to) + vfs.forgetDeadMountpoint(ctx, to) } } } @@ -303,7 +304,7 @@ func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { // // forgetDeadMountpoint is analogous to Linux's // fs/namespace.c:__detach_mounts(). -func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) { +func (vfs *VirtualFilesystem) forgetDeadMountpoint(ctx context.Context, d *Dentry) { var ( vdsToDecRef []VirtualDentry mountsToDecRef []*Mount @@ -316,9 +317,9 @@ func (vfs *VirtualFilesystem) forgetDeadMountpoint(d *Dentry) { vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } } diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 5b009b928..1b5af9f73 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -93,9 +93,9 @@ type epollInterest struct { // NewEpollInstanceFD returns a FileDescription representing a new epoll // instance. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) { +func (vfs *VirtualFilesystem) NewEpollInstanceFD(ctx context.Context) (*FileDescription, error) { vd := vfs.NewAnonVirtualDentry("[eventpoll]") - defer vd.DecRef() + defer vd.DecRef(ctx) ep := &EpollInstance{ interest: make(map[epollInterestKey]*epollInterest), } @@ -110,7 +110,7 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) { } // Release implements FileDescriptionImpl.Release. -func (ep *EpollInstance) Release() { +func (ep *EpollInstance) Release(ctx context.Context) { // Unregister all polled fds. ep.interestMu.Lock() defer ep.interestMu.Unlock() diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 93861fb4a..576ab3920 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -171,7 +171,7 @@ func (fd *FileDescription) TryIncRef() bool { } // DecRef decrements fd's reference count. -func (fd *FileDescription) DecRef() { +func (fd *FileDescription) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { // Unregister fd from all epoll instances. fd.epollMu.Lock() @@ -196,11 +196,11 @@ func (fd *FileDescription) DecRef() { } // Release implementation resources. - fd.impl.Release() + fd.impl.Release(ctx) if fd.writable { fd.vd.mount.EndWrite() } - fd.vd.DecRef() + fd.vd.DecRef(ctx) fd.flagsMu.Lock() // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { @@ -335,7 +335,7 @@ func (fd *FileDescription) Impl() FileDescriptionImpl { type FileDescriptionImpl interface { // Release is called when the associated FileDescription reaches zero // references. - Release() + Release(ctx context.Context) // OnClose is called when a file descriptor representing the // FileDescription is closed. Note that returning a non-nil error does not @@ -526,7 +526,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return stat, err } return fd.impl.Stat(ctx, opts) @@ -541,7 +541,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return err } return fd.impl.SetStat(ctx, opts) @@ -557,7 +557,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return statfs, err } return fd.impl.StatFS(ctx) @@ -674,7 +674,7 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string Start: fd.vd, }) names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return names, err } names, err := fd.impl.Listxattr(ctx, size) @@ -703,7 +703,7 @@ func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) Start: fd.vd, }) val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return val, err } return fd.impl.Getxattr(ctx, *opts) @@ -719,7 +719,7 @@ func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return err } return fd.impl.Setxattr(ctx, *opts) @@ -735,7 +735,7 @@ func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { Start: fd.vd, }) err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name) - vfsObj.putResolvingPath(rp) + vfsObj.putResolvingPath(ctx, rp) return err } return fd.impl.Removexattr(ctx, name) @@ -752,7 +752,7 @@ func (fd *FileDescription) MappedName(ctx context.Context) string { vfsroot := RootFromContext(ctx) s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd) if vfsroot.Ok() { - vfsroot.DecRef() + vfsroot.DecRef(ctx) } return s } diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 3b7e1c273..1cd607c0a 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -80,9 +80,9 @@ type testFD struct { data DynamicBytesSource } -func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { +func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription { vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef() + defer vd.DecRef(ctx) var fd testFD fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}) fd.DynamicBytesFileDescriptionImpl.SetDataSource(data) @@ -90,7 +90,7 @@ func newTestFD(vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesS } // Release implements FileDescriptionImpl.Release. -func (fd *testFD) Release() { +func (fd *testFD) Release(context.Context) { } // SetStatusFlags implements FileDescriptionImpl.SetStatusFlags. @@ -109,11 +109,11 @@ func TestGenCountFD(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } - fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{}) - defer fd.DecRef() + fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &genCount{}) + defer fd.DecRef(ctx) // The first read causes Generate to be called to fill the FD's buffer. buf := make([]byte, 2) @@ -167,11 +167,11 @@ func TestWritable(t *testing.T) { ctx := contexttest.Context(t) vfsObj := &VirtualFilesystem{} - if err := vfsObj.Init(); err != nil { + if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) } - fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"}) - defer fd.DecRef() + fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &storeData{data: "init"}) + defer fd.DecRef(ctx) buf := make([]byte, 10) ioseq := usermem.BytesIOSequence(buf) diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 6bb9ca180..df3758fd1 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -100,12 +100,12 @@ func (fs *Filesystem) TryIncRef() bool { } // DecRef decrements fs' reference count. -func (fs *Filesystem) DecRef() { +func (fs *Filesystem) DecRef(ctx context.Context) { if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { fs.vfs.filesystemsMu.Lock() delete(fs.vfs.filesystems, fs) fs.vfs.filesystemsMu.Unlock() - fs.impl.Release() + fs.impl.Release(ctx) } else if refs < 0 { panic("Filesystem.decRef() called without holding a reference") } @@ -149,7 +149,7 @@ func (fs *Filesystem) DecRef() { type FilesystemImpl interface { // Release is called when the associated Filesystem reaches zero // references. - Release() + Release(ctx context.Context) // Sync "causes all pending modifications to filesystem metadata and cached // file data to be written to the underlying [filesystem]", as by syncfs(2). diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 167b731ac..aff220a61 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -100,7 +100,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) id := uniqueid.GlobalFromContext(ctx) vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id)) - defer vd.DecRef() + defer vd.DecRef(ctx) fd := &Inotify{ id: id, scratch: make([]byte, inotifyEventBaseSize), @@ -118,7 +118,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) // Release implements FileDescriptionImpl.Release. Release removes all // watches and frees all resources for an inotify instance. -func (i *Inotify) Release() { +func (i *Inotify) Release(ctx context.Context) { var ds []*Dentry // We need to hold i.mu to avoid a race with concurrent calls to @@ -144,7 +144,7 @@ func (i *Inotify) Release() { i.mu.Unlock() for _, d := range ds { - d.OnZeroWatches() + d.OnZeroWatches(ctx) } } @@ -350,7 +350,7 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) { // RmWatch looks up an inotify watch for the given 'wd' and configures the // target to stop sending events to this inotify instance. -func (i *Inotify) RmWatch(wd int32) error { +func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { i.mu.Lock() // Find the watch we were asked to removed. @@ -374,7 +374,7 @@ func (i *Inotify) RmWatch(wd int32) error { i.mu.Unlock() if remaining == 0 { - w.target.OnZeroWatches() + w.target.OnZeroWatches(ctx) } // Generate the event for the removal. @@ -462,7 +462,7 @@ func (w *Watches) Remove(id uint64) { // Notify queues a new event with watches in this set. Watches with // IN_EXCL_UNLINK are skipped if the event is coming from a child that has been // unlinked. -func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) { +func (w *Watches) Notify(ctx context.Context, name string, events, cookie uint32, et EventType, unlinked bool) { var hasExpired bool w.mu.RLock() for _, watch := range w.ws { @@ -476,13 +476,13 @@ func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlin w.mu.RUnlock() if hasExpired { - w.cleanupExpiredWatches() + w.cleanupExpiredWatches(ctx) } } // This function is relatively expensive and should only be called where there // are expired watches. -func (w *Watches) cleanupExpiredWatches() { +func (w *Watches) cleanupExpiredWatches(ctx context.Context) { // Because of lock ordering, we cannot acquire Inotify.mu for each watch // owner while holding w.mu. As a result, store expired watches locally // before removing. @@ -495,15 +495,15 @@ func (w *Watches) cleanupExpiredWatches() { } w.mu.RUnlock() for _, watch := range toRemove { - watch.owner.RmWatch(watch.wd) + watch.owner.RmWatch(ctx, watch.wd) } } // HandleDeletion is called when the watch target is destroyed. Clear the // watch set, detach watches from the inotify instances they belong to, and // generate the appropriate events. -func (w *Watches) HandleDeletion() { - w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) +func (w *Watches) HandleDeletion(ctx context.Context) { + w.Notify(ctx, "", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for // the owner of each watch being deleted. Instead, atomically store the @@ -744,12 +744,12 @@ func InotifyEventFromStatMask(mask uint32) uint32 { // InotifyRemoveChild sends the appriopriate notifications to the watch sets of // the child being removed and its parent. Note that unlike most pairs of // parent/child notifications, the child is notified first in this case. -func InotifyRemoveChild(self, parent *Watches, name string) { +func InotifyRemoveChild(ctx context.Context, self, parent *Watches, name string) { if self != nil { - self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) + self.Notify(ctx, "", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) } if parent != nil { - parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) + parent.Notify(ctx, name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) } } @@ -762,13 +762,13 @@ func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, } cookie := uniqueid.InotifyCookie(ctx) if oldParent != nil { - oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) + oldParent.Notify(ctx, oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) } if newParent != nil { - newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) + newParent.Notify(ctx, newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) } // Somewhat surprisingly, self move events do not have a cookie. if renamed != nil { - renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) + renamed.Notify(ctx, "", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) } } diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 32f901bd8..d1d29d0cd 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -200,8 +200,8 @@ func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth if err != nil { return nil, err } - defer root.DecRef() - defer fs.DecRef() + defer root.DecRef(ctx) + defer fs.DecRef(ctx) return vfs.NewDisconnectedMount(fs, root, opts) } @@ -221,7 +221,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr if vd.dentry.dead { vd.dentry.mu.Unlock() vfs.mountMu.Unlock() - vd.DecRef() + vd.DecRef(ctx) return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -243,7 +243,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr // This can't fail since we're holding vfs.mountMu. nextmnt.root.IncRef() vd.dentry.mu.Unlock() - vd.DecRef() + vd.DecRef(ctx) vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, @@ -268,7 +268,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia if err != nil { return err } - defer mnt.DecRef() + defer mnt.DecRef(ctx) if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil { return err } @@ -293,13 +293,13 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti if err != nil { return err } - defer vd.DecRef() + defer vd.DecRef(ctx) if vd.dentry != vd.mount.root { return syserror.EINVAL } vfs.mountMu.Lock() if mntns := MountNamespaceFromContext(ctx); mntns != nil { - defer mntns.DecRef() + defer mntns.DecRef(ctx) if mntns != vd.mount.ns { vfs.mountMu.Unlock() return syserror.EINVAL @@ -335,10 +335,10 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } return nil } @@ -479,7 +479,7 @@ func (mnt *Mount) IncRef() { } // DecRef decrements mnt's reference count. -func (mnt *Mount) DecRef() { +func (mnt *Mount) DecRef(ctx context.Context) { refs := atomic.AddInt64(&mnt.refs, -1) if refs&^math.MinInt64 == 0 { // mask out MSB var vd VirtualDentry @@ -490,10 +490,10 @@ func (mnt *Mount) DecRef() { mnt.vfs.mounts.seq.EndWrite() mnt.vfs.mountMu.Unlock() } - mnt.root.DecRef() - mnt.fs.DecRef() + mnt.root.DecRef(ctx) + mnt.fs.DecRef(ctx) if vd.Ok() { - vd.DecRef() + vd.DecRef(ctx) } } } @@ -506,7 +506,7 @@ func (mntns *MountNamespace) IncRef() { } // DecRef decrements mntns' reference count. -func (mntns *MountNamespace) DecRef() { +func (mntns *MountNamespace) DecRef(ctx context.Context) { vfs := mntns.root.fs.VirtualFilesystem() if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { vfs.mountMu.Lock() @@ -517,10 +517,10 @@ func (mntns *MountNamespace) DecRef() { vfs.mounts.seq.EndWrite() vfs.mountMu.Unlock() for _, vd := range vdsToDecRef { - vd.DecRef() + vd.DecRef(ctx) } for _, mnt := range mountsToDecRef { - mnt.DecRef() + mnt.DecRef(ctx) } } else if refs < 0 { panic("MountNamespace.DecRef() called without holding a reference") @@ -534,7 +534,7 @@ func (mntns *MountNamespace) DecRef() { // getMountAt is analogous to Linux's fs/namei.c:follow_mount(). // // Preconditions: References are held on mnt and d. -func (vfs *VirtualFilesystem) getMountAt(mnt *Mount, d *Dentry) *Mount { +func (vfs *VirtualFilesystem) getMountAt(ctx context.Context, mnt *Mount, d *Dentry) *Mount { // The first mount is special-cased: // // - The caller is assumed to have checked d.isMounted() already. (This @@ -565,7 +565,7 @@ retryFirst: // Raced with umount. continue } - mnt.DecRef() + mnt.DecRef(ctx) mnt = next d = next.root } @@ -578,7 +578,7 @@ retryFirst: // // Preconditions: References are held on mnt and root. vfsroot is not (mnt, // mnt.root). -func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry { +func (vfs *VirtualFilesystem) getMountpointAt(ctx context.Context, mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // // - The caller must have already checked mnt against vfsroot. @@ -602,12 +602,12 @@ retryFirst: if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can only // happen due to a racing change to Mount.key. - parent.DecRef() + parent.DecRef(ctx) goto retryFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.DecRef() - parent.DecRef() + point.DecRef(ctx) + parent.DecRef(ctx) goto retryFirst } mnt = parent @@ -635,16 +635,16 @@ retryFirst: if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can // only happen due to a racing change to Mount.key. - parent.DecRef() + parent.DecRef(ctx) goto retryNotFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.DecRef() - parent.DecRef() + point.DecRef(ctx) + parent.DecRef(ctx) goto retryNotFirst } - d.DecRef() - mnt.DecRef() + d.DecRef(ctx) + mnt.DecRef(ctx) mnt = parent d = point } diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index cd78d66bc..e4da15009 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -47,7 +47,7 @@ func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() @@ -64,12 +64,12 @@ loop: // of FilesystemImpl.PrependPath() may return nil instead. break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { break loop } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true @@ -101,7 +101,7 @@ func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() loop: @@ -112,12 +112,12 @@ loop: if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { return "", nil } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true @@ -145,7 +145,7 @@ func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd haveRef := false defer func() { if haveRef { - vd.DecRef() + vd.DecRef(ctx) } }() unreachable := false @@ -157,13 +157,13 @@ loop: if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { break loop } - nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot) if !nextVD.Ok() { unreachable = true break loop } if haveRef { - vd.DecRef() + vd.DecRef(ctx) } vd = nextVD haveRef = true diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 9d047ff88..3304372d9 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" @@ -136,31 +137,31 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat return rp } -func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { +func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) { rp.root = VirtualDentry{} - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = nil rp.start = nil - rp.releaseErrorState() + rp.releaseErrorState(ctx) resolvingPathPool.Put(rp) } -func (rp *ResolvingPath) decRefStartAndMount() { +func (rp *ResolvingPath) decRefStartAndMount(ctx context.Context) { if rp.flags&rpflagsHaveStartRef != 0 { - rp.start.DecRef() + rp.start.DecRef(ctx) } if rp.flags&rpflagsHaveMountRef != 0 { - rp.mount.DecRef() + rp.mount.DecRef(ctx) } } -func (rp *ResolvingPath) releaseErrorState() { +func (rp *ResolvingPath) releaseErrorState(ctx context.Context) { if rp.nextStart != nil { - rp.nextStart.DecRef() + rp.nextStart.DecRef(ctx) rp.nextStart = nil } if rp.nextMount != nil { - rp.nextMount.DecRef() + rp.nextMount.DecRef(ctx) rp.nextMount = nil } } @@ -236,13 +237,13 @@ func (rp *ResolvingPath) Advance() { // Restart resets the stream of path components represented by rp to its state // on entry to the current FilesystemImpl method. -func (rp *ResolvingPath) Restart() { +func (rp *ResolvingPath) Restart(ctx context.Context) { rp.pit = rp.origParts[rp.numOrigParts-1] rp.mustBeDir = rp.mustBeDirOrig rp.symlinks = rp.symlinksOrig rp.curPart = rp.numOrigParts - 1 copy(rp.parts[:], rp.origParts[:rp.numOrigParts]) - rp.releaseErrorState() + rp.releaseErrorState(ctx) } func (rp *ResolvingPath) relpathCommit() { @@ -260,13 +261,13 @@ func (rp *ResolvingPath) relpathCommit() { // Mount, CheckRoot returns (unspecified, non-nil error). Otherwise, path // resolution should resolve d's parent normally, and CheckRoot returns (false, // nil). -func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) { +func (rp *ResolvingPath) CheckRoot(ctx context.Context, d *Dentry) (bool, error) { if d == rp.root.dentry && rp.mount == rp.root.mount { // At contextual VFS root (due to e.g. chroot(2)). return true, nil } else if d == rp.mount.root { // At mount root ... - vd := rp.vfs.getMountpointAt(rp.mount, rp.root) + vd := rp.vfs.getMountpointAt(ctx, rp.mount, rp.root) if vd.Ok() { // ... of non-root mount. rp.nextMount = vd.mount @@ -283,11 +284,11 @@ func (rp *ResolvingPath) CheckRoot(d *Dentry) (bool, error) { // to d. If d is a mount point, such that path resolution should switch to // another Mount, CheckMount returns a non-nil error. Otherwise, CheckMount // returns nil. -func (rp *ResolvingPath) CheckMount(d *Dentry) error { +func (rp *ResolvingPath) CheckMount(ctx context.Context, d *Dentry) error { if !d.isMounted() { return nil } - if mnt := rp.vfs.getMountAt(rp.mount, d); mnt != nil { + if mnt := rp.vfs.getMountAt(ctx, rp.mount, d); mnt != nil { rp.nextMount = mnt return resolveMountPointError{} } @@ -389,11 +390,11 @@ func (rp *ResolvingPath) HandleJump(target VirtualDentry) error { return resolveMountRootOrJumpError{} } -func (rp *ResolvingPath) handleError(err error) bool { +func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { switch err.(type) { case resolveMountRootOrJumpError: // Switch to the new Mount. We hold references on the Mount and Dentry. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.nextMount rp.start = rp.nextStart rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef @@ -412,7 +413,7 @@ func (rp *ResolvingPath) handleError(err error) bool { case resolveMountPointError: // Switch to the new Mount. We hold a reference on the Mount, but // borrow the reference on the mount root from the Mount. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.nextMount rp.start = rp.nextMount.root rp.flags = rp.flags&^rpflagsHaveStartRef | rpflagsHaveMountRef @@ -423,12 +424,12 @@ func (rp *ResolvingPath) handleError(err error) bool { // path. rp.relpathCommit() // Restart path resolution on the new Mount. - rp.releaseErrorState() + rp.releaseErrorState(ctx) return true case resolveAbsSymlinkError: // Switch to the new Mount. References are borrowed from rp.root. - rp.decRefStartAndMount() + rp.decRefStartAndMount(ctx) rp.mount = rp.root.mount rp.start = rp.root.dentry rp.flags &^= rpflagsHaveMountRef | rpflagsHaveStartRef @@ -440,7 +441,7 @@ func (rp *ResolvingPath) handleError(err error) bool { // path, including the symlink target we just prepended. rp.relpathCommit() // Restart path resolution on the new Mount. - rp.releaseErrorState() + rp.releaseErrorState(ctx) return true default: diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 522e27475..9c2420683 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -122,7 +122,7 @@ type VirtualFilesystem struct { } // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. -func (vfs *VirtualFilesystem) Init() error { +func (vfs *VirtualFilesystem) Init(ctx context.Context) error { if vfs.mountpoints != nil { panic("VFS already initialized") } @@ -145,7 +145,7 @@ func (vfs *VirtualFilesystem) Init() error { devMinor: anonfsDevMinor, } anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs) - defer anonfs.vfsfs.DecRef() + defer anonfs.vfsfs.DecRef(ctx) anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{}) if err != nil { // We should not be passing any MountOptions that would cause @@ -192,11 +192,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -214,11 +214,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return vd, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return VirtualDentry{}, err } } @@ -236,7 +236,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return parentVD, name, nil } if checkInvariants { @@ -244,8 +244,8 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return VirtualDentry{}, "", err } } @@ -260,14 +260,14 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } if !newpop.Path.Begin.Ok() { - oldVD.DecRef() + oldVD.DecRef(ctx) if newpop.Path.Absolute { return syserror.EEXIST } return syserror.ENOENT } if newpop.FollowFinalSymlink { - oldVD.DecRef() + oldVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink") return syserror.EINVAL } @@ -276,8 +276,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(rp) - oldVD.DecRef() + vfs.putResolvingPath(ctx, rp) + oldVD.DecRef(ctx) return nil } if checkInvariants { @@ -285,9 +285,9 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - oldVD.DecRef() + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) + oldVD.DecRef(ctx) return err } } @@ -313,7 +313,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -321,8 +321,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -346,7 +346,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -354,8 +354,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -408,31 +408,31 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) if opts.FileExec { if fd.Mount().Flags.NoExec { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EACCES } // Only a regular file can be executed. stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE}) if err != nil { - fd.DecRef() + fd.DecRef(ctx) return nil, err } if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG { - fd.DecRef() + fd.DecRef(ctx) return nil, syserror.EACCES } } - fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent) + fd.Dentry().InotifyWithParent(ctx, linux.IN_OPEN, 0, PathEvent) return fd, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } @@ -444,11 +444,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return target, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return "", err } } @@ -472,19 +472,19 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti return err } if oldName == "." || oldName == ".." { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) return syserror.EBUSY } if !newpop.Path.Begin.Ok() { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) if newpop.Path.Absolute { return syserror.EBUSY } return syserror.ENOENT } if newpop.FollowFinalSymlink { - oldParentVD.DecRef() + oldParentVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink") return syserror.EINVAL } @@ -497,8 +497,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - vfs.putResolvingPath(rp) - oldParentVD.DecRef() + vfs.putResolvingPath(ctx, rp) + oldParentVD.DecRef(ctx) return nil } if checkInvariants { @@ -506,9 +506,9 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - oldParentVD.DecRef() + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) + oldParentVD.DecRef(ctx) return err } } @@ -531,7 +531,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -539,8 +539,8 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -552,11 +552,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -568,11 +568,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return stat, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return linux.Statx{}, err } } @@ -585,11 +585,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return statfs, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return linux.Statfs{}, err } } @@ -612,7 +612,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -620,8 +620,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -644,7 +644,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } if checkInvariants { @@ -652,8 +652,8 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -671,7 +671,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return bep, nil } if checkInvariants { @@ -679,8 +679,8 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err)) } } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } @@ -693,7 +693,7 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede for { names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return names, nil } if err == syserror.ENOTSUP { @@ -701,11 +701,11 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return nil, err } } @@ -718,11 +718,11 @@ func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Creden for { val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return val, nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return "", err } } @@ -735,11 +735,11 @@ func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Creden for { err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -751,11 +751,11 @@ func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Cre for { err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) if err == nil { - vfs.putResolvingPath(rp) + vfs.putResolvingPath(ctx, rp) return nil } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) + if !rp.handleError(ctx, err) { + vfs.putResolvingPath(ctx, rp) return err } } @@ -777,7 +777,7 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { if err := fs.impl.Sync(ctx); err != nil && retErr == nil { retErr = err } - fs.DecRef() + fs.DecRef(ctx) } return retErr } @@ -831,9 +831,9 @@ func (vd VirtualDentry) IncRef() { // DecRef decrements the reference counts on the Mount and Dentry represented // by vd. -func (vd VirtualDentry) DecRef() { - vd.dentry.DecRef() - vd.mount.DecRef() +func (vd VirtualDentry) DecRef(ctx context.Context) { + vd.dentry.DecRef(ctx) + vd.mount.DecRef(ctx) } // Mount returns the Mount associated with vd. It does not take a reference on diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e0db6cf54..6c137f693 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -12,6 +12,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/refs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 04ae58e59..22b0a12bd 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -64,14 +65,14 @@ func (d *Device) beforeSave() { } // Release implements fs.FileOperations.Release. -func (d *Device) Release() { +func (d *Device) Release(ctx context.Context) { d.mu.Lock() defer d.mu.Unlock() // Decrease refcount if there is an endpoint associated with this file. if d.endpoint != nil { d.endpoint.RemoveNotify(d.notifyHandle) - d.endpoint.DecRef() + d.endpoint.DecRef(ctx) d.endpoint = nil } } @@ -341,8 +342,8 @@ type tunEndpoint struct { } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. -func (e *tunEndpoint) DecRef() { - e.DecRefWithDestructor(func() { +func (e *tunEndpoint) DecRef(ctx context.Context) { + e.DecRefWithDestructor(ctx, func(context.Context) { e.stack.RemoveNIC(e.nicID) }) } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 59639ba19..9dd5b0184 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -640,7 +640,7 @@ func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Confi func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error { root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, m := range c.mounts { log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options) @@ -868,7 +868,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns if err != nil { return fmt.Errorf("can't find mount destination %q: %v", m.Destination, err) } - defer dirent.DecRef() + defer dirent.DecRef(ctx) if err := mns.Mount(ctx, dirent, inode); err != nil { return fmt.Errorf("mount %q error: %v", m.Destination, err) } @@ -889,12 +889,12 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun if err != nil { return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err) } - defer target.DecRef() + defer target.DecRef(ctx) // Take a ref on the inode that is about to be (re)-mounted. source.root.IncRef() if err := mns.Mount(ctx, target, source.root); err != nil { - source.root.DecRef() + source.root.DecRef(ctx) return fmt.Errorf("bind mount %q error: %v", mount.Destination, err) } @@ -997,12 +997,12 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.M switch err { case nil: // Found '/tmp' in filesystem, check if it's empty. - defer tmp.DecRef() + defer tmp.DecRef(ctx) f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true}) if err != nil { return err } - defer f.DecRef() + defer f.DecRef(ctx) serializer := &fs.CollectEntriesSerializer{} if err := f.Readdir(ctx, serializer); err != nil { return err diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 9cd9c5909..e0d077f5a 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -346,7 +346,7 @@ func New(args Args) (*Loader, error) { if err != nil { return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err) } - defer hostFilesystem.DecRef() + defer hostFilesystem.DecRef(k.SupervisorContext()) hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{}) if err != nil { return nil, fmt.Errorf("failed to create hostfs mount: %v", err) @@ -755,7 +755,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn return nil, fmt.Errorf("creating process: %v", err) } // CreateProcess takes a reference on FDTable if successful. - info.procArgs.FDTable.DecRef() + info.procArgs.FDTable.DecRef(ctx) // Set the foreground process group on the TTY to the global init process // group, since that is what we are about to start running. @@ -890,22 +890,20 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) { // Add the HOME environment variable if it is not already set. if kernel.VFS2Enabled { - defer args.MountNamespaceVFS2.DecRef() - root := args.MountNamespaceVFS2.Root() - defer root.DecRef() ctx := vfs.WithRoot(l.k.SupervisorContext(), root) + defer args.MountNamespaceVFS2.DecRef(ctx) + defer root.DecRef(ctx) envv, err := user.MaybeAddExecUserHomeVFS2(ctx, args.MountNamespaceVFS2, args.KUID, args.Envv) if err != nil { return 0, err } args.Envv = envv } else { - defer args.MountNamespace.DecRef() - root := args.MountNamespace.Root() - defer root.DecRef() ctx := fs.WithRoot(l.k.SupervisorContext(), root) + defer args.MountNamespace.DecRef(ctx) + defer root.DecRef(ctx) envv, err := user.MaybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv) if err != nil { return 0, err @@ -1263,7 +1261,7 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F fdTable := k.NewFDTable() ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, console, stdioFDs) if err != nil { - fdTable.DecRef() + fdTable.DecRef(ctx) return nil, nil, nil, err } return fdTable, ttyFile, ttyFileVFS2, nil diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 8e6fe57e1..aa3fdf96c 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -450,13 +450,13 @@ func TestCreateMountNamespace(t *testing.T) { } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range tc.expectedPaths { maxTraversals := uint(0) if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil { t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) } else { - d.DecRef() + d.DecRef(ctx) } } }) @@ -491,7 +491,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) for _, p := range tc.expectedPaths { target := &vfs.PathOperation{ Root: root, @@ -502,7 +502,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { if d, err := l.k.VFS().GetDentryAt(ctx, l.root.procArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil { t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) } else { - d.DecRef() + d.DecRef(ctx) } } }) diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index cfe2d36aa..252ca07e3 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -103,7 +103,7 @@ func registerFilesystems(k *kernel.Kernel) error { if err != nil { return fmt.Errorf("creating devtmpfs accessor: %w", err) } - defer a.Release() + defer a.Release(ctx) if err := a.UserspaceInit(ctx); err != nil { return fmt.Errorf("initializing userspace: %w", err) @@ -252,7 +252,7 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error { root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) target := &vfs.PathOperation{ Root: root, Start: root, @@ -387,7 +387,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds } root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) pop := vfs.PathOperation{ Root: root, Start: root, @@ -481,10 +481,10 @@ func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Co if err != nil { return err } - defer newMnt.DecRef() + defer newMnt.DecRef(ctx) root := mns.Root() - defer root.DecRef() + defer root.DecRef(ctx) if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil { return err } -- cgit v1.2.3 From 35312a95c4c8626365b4ece5ffb0bcab44b4bede Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Wed, 5 Aug 2020 20:45:02 -0700 Subject: Add loss recovery option for TCP. /proc/sys/net/ipv4/tcp_recovery is used to enable RACK loss recovery in TCP. PiperOrigin-RevId: 325157807 --- pkg/sentry/fs/proc/sys_net.go | 95 +++++++++++++++++++++++++++++++++++++ pkg/sentry/fsimpl/proc/tasks_sys.go | 49 ++++++++++++++++++- pkg/sentry/inet/inet.go | 17 +++++++ pkg/sentry/inet/test_stack.go | 12 +++++ pkg/sentry/socket/hostinet/stack.go | 11 +++++ pkg/sentry/socket/netstack/stack.go | 14 ++++++ pkg/tcpip/transport/tcp/protocol.go | 33 +++++++++++++ test/syscalls/linux/proc_net.cc | 38 +++++++++++++++ 8 files changed, 268 insertions(+), 1 deletion(-) (limited to 'pkg/sentry/socket/netstack') diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 702fdd392..8615b60f0 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -272,6 +272,96 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled) } +// +stateify savable +type tcpRecovery struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + recovery inet.TCPLossRecovery +} + +func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ts := &tcpRecovery{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ts, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (*tcpRecovery) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (r *tcpRecovery) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &tcpRecoveryFile{ + tcpRecovery: r, + stack: r.stack, + }), nil +} + +// +stateify savable +type tcpRecoveryFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + tcpRecovery *tcpRecovery + + stack inet.Stack `state:"wait"` +} + +// Read implements fs.FileOperations.Read. +func (f *tcpRecoveryFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + recovery, err := f.stack.TCPRecovery() + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = recovery + s := fmt.Sprintf("%d\n", f.tcpRecovery.recovery) + n, err := dst.CopyOut(ctx, []byte(s)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = inet.TCPLossRecovery(v) + if err := f.tcpRecovery.stack.SetTCPRecovery(f.tcpRecovery.recovery); err != nil { + return 0, err + } + return n, nil +} + func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -351,6 +441,11 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem) } + // Add tcp_recovery. + if _, err := s.TCPRecovery(); err == nil { + contents["tcp_recovery"] = newTCPRecoveryInode(ctx, msrc, s) + } + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 6dac2afa4..b71778128 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -55,7 +55,8 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ - "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), + "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}), + "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -207,3 +208,49 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset *d.enabled = v != 0 return n, d.stack.SetTCPSACKEnabled(*d.enabled) } + +// tcpRecoveryData implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/tcp_recovery. +// +// +stateify savable +type tcpRecoveryData struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` +} + +var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil) + +// Generate implements vfs.DynamicBytesSource. +func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error { + recovery, err := d.stack.TCPRecovery() + if err != nil { + return err + } + + buf.WriteString(fmt.Sprintf("%d\n", recovery)) + return nil +} + +func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit the amount of memory allocated. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if err := d.stack.SetTCPRecovery(inet.TCPLossRecovery(v)); err != nil { + return 0, err + } + return n, nil +} diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 2916a0644..c0b4831d1 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -56,6 +56,12 @@ type Stack interface { // settings. SetTCPSACKEnabled(enabled bool) error + // TCPRecovery returns the TCP loss detection algorithm. + TCPRecovery() (TCPLossRecovery, error) + + // SetTCPRecovery attempts to change TCP loss detection algorithm. + SetTCPRecovery(recovery TCPLossRecovery) error + // Statistics reports stack statistics. Statistics(stat interface{}, arg string) error @@ -189,3 +195,14 @@ type StatSNMPUDP [8]uint64 // StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. type StatSNMPUDPLite [8]uint64 + +// TCPLossRecovery indicates TCP loss detection and recovery methods to use. +type TCPLossRecovery int32 + +// Loss recovery constants from include/net/tcp.h which are used to set +// /proc/sys/net/ipv4/tcp_recovery. +const ( + TCP_RACK_LOSS_DETECTION TCPLossRecovery = 1 << iota + TCP_RACK_STATIC_REO_WND + TCP_RACK_NO_DUPTHRESH +) diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index d8961fc94..9771f01fc 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -25,6 +25,7 @@ type TestStack struct { TCPRecvBufSize TCPBufferSize TCPSendBufSize TCPBufferSize TCPSACKFlag bool + Recovery TCPLossRecovery } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -91,6 +92,17 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { return nil } +// TCPRecovery implements Stack.TCPRecovery. +func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { + return s.Recovery, nil +} + +// SetTCPRecovery implements Stack.SetTCPRecovery. +func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { + s.Recovery = recovery + return nil +} + // Statistics implements inet.Stack.Statistics. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index a48082631..fda3dcb35 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -53,6 +53,7 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool @@ -350,6 +351,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 67737ae87..f0fe18684 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -207,6 +207,20 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + var recovery tcp.Recovery + if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + return inet.TCPLossRecovery(recovery), nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { switch stats := stat.(type) { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b34e47bbd..d9abb8d94 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -80,6 +80,25 @@ const ( // enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// Recovery is used by stack.(*Stack).TransportProtocolOption to +// set loss detection algorithm in TCP. +type Recovery int32 + +const ( + // RACKLossDetection indicates RACK is used for loss detection and + // recovery. + RACKLossDetection Recovery = 1 << iota + + // RACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + RACKStaticReoWnd + + // RACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + RACKNoDupTh +) + // DelayEnabled is used by stack.(Stack*).TransportProtocolOption to // enable/disable Nagle's algorithm in TCP. type DelayEnabled bool @@ -161,6 +180,7 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool + recovery Recovery delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption @@ -280,6 +300,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case Recovery: + p.mu.Lock() + p.recovery = Recovery(v) + p.mu.Unlock() + return nil + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) @@ -394,6 +420,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *Recovery: + p.mu.RLock() + *v = Recovery(p.recovery) + p.mu.RUnlock() + return nil + case *DelayEnabled: p.mu.RLock() *v = DelayEnabled(p.delayEnabled) @@ -535,6 +567,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, + recovery: RACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 3377b65cf..4fab097f4 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -477,6 +477,44 @@ TEST(ProcNetSnmp, CheckSnmp) { EXPECT_EQ(value_count, 1); } +TEST(ProcSysNetIpv4Recovery, Exists) { + EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_recovery", O_RDONLY), + SyscallSucceeds()); +} + +TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) { + // TODO(b/162988252): Enable save/restore for this test after the bug is + // fixed. + DisableSave ds; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/tcp_recovery", O_RDWR)); + + char buf[10] = {'\0'}; + char to_write = '2'; + + // Check initial value is set to 1. + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "1\n"), 0); + + // Set tcp_recovery to one of the allowed constants. + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "2\n"), 0); + + // Set tcp_recovery to any random value. + char kMessage[] = "100"; + EXPECT_THAT(PwriteFd(fd.get(), kMessage, strlen(kMessage), 0), + SyscallSucceedsWithValue(strlen(kMessage))); + EXPECT_THAT(PreadFd(fd.get(), buf, sizeof(kMessage), 0), + SyscallSucceedsWithValue(sizeof(kMessage))); + EXPECT_EQ(strcmp(buf, "100\n"), 0); +} } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3