diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/inet/inet.go | 3 | ||||
-rw-r--r-- | pkg/sentry/inet/test_stack.go | 50 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/stack.go | 29 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/route/protocol.go | 43 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/endpoint.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/mmap.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/mmap_unsafe.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/packet_dispatchers.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/link/qdisc/fifo/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/rawfile_unsafe.go | 55 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 77 |
15 files changed, 281 insertions, 111 deletions
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80dda1559..b121fc1b4 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -27,6 +27,9 @@ type Stack interface { // integers. Interfaces() map[int32]Interface + // RemoveInterface removes the specified network interface. + RemoveInterface(idx int32) error + // InterfaceAddrs returns all network interface addresses as a mapping from // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 218d9dafc..621f47e1f 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -45,23 +45,29 @@ func NewTestStack() *TestStack { } } -// Interfaces implements Stack.Interfaces. +// Interfaces implements Stack. func (s *TestStack) Interfaces() map[int32]Interface { return s.InterfacesMap } -// InterfaceAddrs implements Stack.InterfaceAddrs. +// RemoveInterface implements Stack. +func (s *TestStack) RemoveInterface(idx int32) error { + delete(s.InterfacesMap, idx) + return nil +} + +// InterfaceAddrs implements Stack. func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } -// AddInterfaceAddr implements Stack.AddInterfaceAddr. +// AddInterfaceAddr implements Stack. func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) return nil } -// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +// RemoveInterfaceAddr implements Stack. func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { interfaceAddrs, ok := s.InterfaceAddrsMap[idx] if !ok { @@ -79,94 +85,94 @@ func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } -// SupportsIPv6 implements Stack.SupportsIPv6. +// SupportsIPv6 implements Stack. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag } -// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize. +// TCPReceiveBufferSize implements Stack. func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) { return s.TCPRecvBufSize, nil } -// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize. +// SetTCPReceiveBufferSize implements Stack. func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error { s.TCPRecvBufSize = size return nil } -// TCPSendBufferSize implements Stack.TCPSendBufferSize. +// TCPSendBufferSize implements Stack. func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) { return s.TCPSendBufSize, nil } -// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize. +// SetTCPSendBufferSize implements Stack. func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error { s.TCPSendBufSize = size return nil } -// TCPSACKEnabled implements Stack.TCPSACKEnabled. +// TCPSACKEnabled implements Stack. func (s *TestStack) TCPSACKEnabled() (bool, error) { return s.TCPSACKFlag, nil } -// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled. +// SetTCPSACKEnabled implements Stack. func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { s.TCPSACKFlag = enabled return nil } -// TCPRecovery implements Stack.TCPRecovery. +// TCPRecovery implements Stack. func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { return s.Recovery, nil } -// SetTCPRecovery implements Stack.SetTCPRecovery. +// SetTCPRecovery implements Stack. func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { s.Recovery = recovery return nil } -// Statistics implements inet.Stack.Statistics. +// Statistics implements Stack. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil } -// RouteTable implements Stack.RouteTable. +// RouteTable implements Stack. func (s *TestStack) RouteTable() []Route { return s.RouteList } -// Resume implements Stack.Resume. +// Resume implements Stack. func (s *TestStack) Resume() {} -// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +// RegisteredEndpoints implements Stack. func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } -// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +// CleanupEndpoints implements Stack. func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { return nil } -// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +// RestoreCleanupEndpoints implements Stack. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// SetForwarding implements inet.Stack.SetForwarding. +// SetForwarding implements Stack. func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { s.IPForwarding = enable return nil } -// PortRange implements inet.Stack.PortRange. +// PortRange implements Stack. func (*TestStack) PortRange() (uint16, uint16) { // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). return 32768, 28232 } -// SetPortRange implements inet.Stack.SetPortRange. +// SetPortRange implements Stack. func (*TestStack) SetPortRange(start uint16, end uint16) error { // No-op. return nil diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 7a4e78a5f..61111ac6c 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -309,6 +309,11 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return interfaces } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (*Stack) RemoveInterface(int32) error { + return linuxerr.EACCES +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { addrs := make(map[int32][]inet.InterfaceAddr) @@ -319,12 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } // RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. -func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } @@ -339,7 +344,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { } // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. -func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPReceiveBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -349,7 +354,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { } // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. -func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPSendBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(bool) error { +func (*Stack) SetTCPSACKEnabled(bool) error { return linuxerr.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { +func (*Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return linuxerr.EACCES } @@ -470,19 +475,19 @@ func (s *Stack) RouteTable() []inet.Route { } // Resume implements inet.Stack.Resume. -func (s *Stack) Resume() {} +func (*Stack) Resume() {} // RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. -func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } // CleanupEndpoints implements inet.Stack.CleanupEndpoints. -func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. -func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} +func (*Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { +func (*Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return linuxerr.EACCES } @@ -493,6 +498,6 @@ func (*Stack) PortRange() (uint16, uint16) { } // SetPortRange implements inet.Stack.SetPortRange. -func (*Stack) SetPortRange(start uint16, end uint16) error { +func (*Stack) SetPortRange(uint16, uint16) error { return linuxerr.EACCES } diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 86f6419dc..d526acb73 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -161,6 +161,47 @@ func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlin return nil } +// delLink handles RTM_DELLINK requests. +func (p *Protocol) delLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifinfomsg linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifinfomsg) + if !ok { + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + // The index is unspecified, search by the interface name. + ahdr, value, _, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + ifname := string(value[:len(value)-1]) + for idx, ifa := range stack.Interfaces() { + if ifname == ifa.Name { + ifinfomsg.Index = idx + break + } + } + default: + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + return syserr.ErrNoDevice + } + } + return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index)) +} + // addNewLinkMessage appends RTM_NEWLINK message for the given interface into // the message set. func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { @@ -537,6 +578,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms switch hdr.Type { case linux.RTM_GETLINK: return p.getLink(ctx, msg, ms) + case linux.RTM_DELLINK: + return p.delLink(ctx, msg, ms) case linux.RTM_GETROUTE: return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 0fd0ad32c..208ab9909 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -71,6 +71,12 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return is } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (s *Stack) RemoveInterface(idx int32) error { + nic := tcpip.NICID(idx) + return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError() +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { nicAddrs := make(map[int32][]inet.InterfaceAddr) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index e8e716db0..48356c343 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -56,6 +56,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { + stop() dispatch() (bool, tcpip.Error) } @@ -381,16 +382,27 @@ func isSocketFD(fd int) (bool, error) { // Attach launches the goroutine that reads packets from the file descriptor and // dispatches them via the provided dispatcher. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - // Link endpoints are not savable. When transportation endpoints are - // saved, they stop sending outgoing packets and all incoming packets - // are rejected. - for i := range e.inboundDispatchers { - e.wg.Add(1) - go func(i int) { // S/R-SAFE: See above. - e.dispatchLoop(e.inboundDispatchers[i]) - e.wg.Done() - }(i) + // nil means the NIC is being removed. + if dispatcher == nil && e.dispatcher != nil { + for _, dispatcher := range e.inboundDispatchers { + dispatcher.stop() + } + e.Wait() + e.dispatcher = nil + return + } + if dispatcher != nil && e.dispatcher == nil { + e.dispatcher = dispatcher + // Link endpoints are not savable. When transportation endpoints are + // saved, they stop sending outgoing packets and all incoming packets + // are rejected. + for i := range e.inboundDispatchers { + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) + } } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index bfae34ab9..3f516cab5 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -114,6 +114,7 @@ func (t tPacketHdr) Payload() []byte { // packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. // See: mmap_amd64_unsafe.go for implementation details. type packetMMapDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -129,18 +130,18 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR) + if errno != 0 { if errno == unix.EINTR { continue } - return nil, rawfile.TranslateErrno(errno) + return nil, stopped, rawfile.TranslateErrno(errno) + } + if stopped { + return nil, true, nil } if hdr.tpStatus()&tpStatusCopy != 0 { // This frame is truncated so skip it after flipping the @@ -158,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { // Release packet to kernel. hdr.setTPStatus(tpStatusKernel) d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil + return pkt, false, nil } // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { + pkt, stopped, err := d.readMMappedPacket() + if err != nil || stopped { return false, err } var ( diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 58d5dfeef..5b786169a 100644 --- a/pkg/tcpip/link/fdbased/mmap_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -47,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) { } func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &packetMMapDispatcher{ - fd: fd, - e: e, + stopFd: stopFd, + fd: fd, + e: e, } pageSize := unix.Getpagesize() if tpBlockSize%pageSize != 0 { diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index ab2855a63..fab34c5fa 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -18,6 +18,8 @@ package fdbased import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -114,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { return buffer.NewVectorisedView(n, views) } +// stopFd is an eventfd used to signal the stop of a dispatcher. +type stopFd struct { + efd int +} + +func newStopFd() (stopFd, error) { + efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) + if err != nil { + return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err) + } + return stopFd{efd: efd}, nil +} + +// stop writes to the eventfd and notifies the dispatcher to stop. It does not +// block. +func (s *stopFd) stop() { + increment := []byte{1, 0, 0, 0, 0, 0, 0, 0} + if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil { + // There are two possible errors documented in eventfd(2) for writing: + // 1. We are writing 8 bytes and not 0xffffffffffffff, thus no EINVAL. + // 2. stop is only supposed to be called once, it can't reach the limit, + // thus no EAGAIN. + panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment))) + } +} + // readVDispatcher uses readv() system call to read inbound packets and // dispatches them. type readVDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -128,7 +157,15 @@ type readVDispatcher struct { } func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} + stopFd, err := newStopFd() + if err != nil { + return nil, err + } + d := &readVDispatcher{ + stopFd: stopFd, + fd: fd, + e: e, + } skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil @@ -136,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { // dispatch reads one packet from the file descriptor and dispatches it. func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { - n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) - if n == 0 || err != nil { + n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs()) + if n <= 0 || err != nil { return false, err } @@ -184,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { // recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and // dispatches them. type recvMMsgDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -207,7 +245,12 @@ const ( ) func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &recvMMsgDispatcher{ + stopFd: stopFd, fd: fd, e: e, bufs: make([]*iovecBuffer, MaxMsgsPerRecv), @@ -235,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { d.msgHdrs[k].Msg.SetIovlen(iovLen) } - nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) - if err != nil { + nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs) + if nMsgs == -1 || err != nil { return false, err } // Process each of received packets. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b1a28491d..40bd5560b 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -115,6 +115,13 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + // nil means the NIC is being removed. + if dispatcher == nil { + e.lower.Attach(nil) + e.Wait() + e.dispatcher = nil + return + } e.dispatcher = dispatcher e.lower.Attach(e) } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 53448a641..e76fc55b6 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -170,46 +170,63 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { } } -// BlockingReadv reads from a file descriptor that is set up as non-blocking and -// stores the data in a list of iovecs buffers. If no data is available, it will -// block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { +// BlockingReadvUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the data in a list of iovecs buffers. If no data is +// available, it will block in a poll() syscall until the file descriptor +// becomes readable or stop is signalled (efd becomes readable). Returns -1 in +// the latter case. +func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } -// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking -// and stores the received messages in a slice of MMsgHdr structures. If no data -// is available, it will block in a poll() syscall until the file descriptor -// becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { +// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the received messages in a slice of MMsgHdr +// structures. If no data is available, it will block in a poll() syscall until +// the file descriptor becomes readable or stop is signalled (efd becomes +// readable). Returns -1 in the latter case. +func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR { + if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } + +// BlockingPollUntilStopped polls for events on fd or until a stop is signalled +// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN. +func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) { + pevents := [...]PollEvent{ + { + FD: int32(efd), + Events: unix.POLLIN, + }, + { + FD: int32(fd), + Events: events, + }, + } + _, errno := BlockingPoll(&pevents[0], len(pevents), nil) + return pevents[0].Revents&unix.POLLIN != 0, errno +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index c90974693..2257f728e 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -39,7 +39,6 @@ go_test( "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 4a4448cf9..73407be67 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -3339,7 +3338,7 @@ func TestCloseLocking(t *testing.T) { defer wg.Done() for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil { t.Errorf("CreateNIC(%d, _): %s", nicID2, err) return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 81fabe29a..c73890c4c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -780,6 +780,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { if !ok { return &tcpip.ErrUnknownNICID{} } + if nic.IsLoopback() { + return &tcpip.ErrNotSupported{} + } delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 21951d05a..3089c0ef4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -719,38 +719,59 @@ func TestRemoveUnknownNIC(t *testing.T) { } func TestRemoveNIC(t *testing.T) { - const nicID = 1 + for _, tt := range []struct { + name string + linkep stack.LinkEndpoint + expectErr tcpip.Error + }{ + { + name: "loopback", + linkep: loopback.New(), + expectErr: &tcpip.ErrNotSupported{}, + }, + { + name: "channel", + linkep: channel.New(0, defaultMTU, ""), + expectErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + e := linkEPWithMockedAttach{ + LinkEndpoint: tt.linkep, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want { + t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want) + } + if tt.expectErr == nil { + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } + } + }) } } |